github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/dask/dask_runner.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""DaskRunner, executing remote jobs on Dask.distributed.

The DaskRunner is a runner implementation that executes a graph of
transformations across processes and workers via Dask distributed's
scheduler.
"""
import argparse
import dataclasses
import typing as t

from apache_beam import pvalue
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.dask.overrides import dask_overrides
from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
from apache_beam.runners.dask.transform_evaluator import NoOp
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineState
from apache_beam.utils.interactive_utils import is_in_notebook


class DaskOptions(PipelineOptions):
  @staticmethod
  def _parse_timeout(candidate):
    try:
      return int(candidate)
    except (TypeError, ValueError):
      import dask
      return dask.config.no_default

  @classmethod
  def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        '--dask_client_address',
        dest='address',
        type=str,
        default=None,
        help='Address of a dask Scheduler server. Will default to a '
        '`dask.LocalCluster()`.')
    parser.add_argument(
        '--dask_connection_timeout',
        dest='timeout',
        type=DaskOptions._parse_timeout,
        help='Timeout duration for initial connection to the scheduler.')
    parser.add_argument(
        '--dask_scheduler_file',
        dest='scheduler_file',
        type=str,
        default=None,
        help='Path to a file with scheduler information if available.')
    # TODO(alxr): Add options for security.
    parser.add_argument(
        '--dask_client_name',
        dest='name',
        type=str,
        default=None,
        help='Gives the client a name that will be included in logs generated '
        'on the scheduler for matters relating to this client.')
    parser.add_argument(
        '--dask_connection_limit',
        dest='connection_limit',
        type=int,
        default=512,
        help='The number of open comms to maintain at once in the connection '
        'pool.')
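
# How these options reach Dask: each argument's `dest` above matches a
# `dask.distributed.Client` keyword argument (address, timeout,
# scheduler_file, name, connection_limit), and run_pipeline() below passes
# the parsed, non-default options straight to the Client constructor.
# A minimal sketch of that flow (the scheduler address here is an
# illustrative assumption, not a real endpoint):
#
#   opts = DaskOptions(['--dask_client_address', 'tcp://10.0.0.1:8786'])
#   opts.get_all_options(drop_default=True)
#   # -> {'address': 'tcp://10.0.0.1:8786'}; options left at their
#   # defaults are dropped, so only explicitly-set flags reach the Client.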

@dataclasses.dataclass
class DaskRunnerResult(PipelineResult):
  # Class-scoped import so the field annotations below can reference
  # dask's distributed types.
  from dask import distributed

  client: distributed.Client
  futures: t.Sequence[distributed.Future]

  def __post_init__(self):
    super().__init__(PipelineState.RUNNING)

  def wait_until_finish(self, duration=None) -> str:
    try:
      if duration is not None:
        # Beam specifies durations in milliseconds; dask expects seconds.
        duration /= 1000
      self.client.wait_for_workers(timeout=duration)
      self.client.gather(self.futures, errors='raise')
      self._state = PipelineState.DONE
    except:  # pylint: disable=broad-except
      self._state = PipelineState.FAILED
      raise
    return self._state

  def cancel(self) -> str:
    self._state = PipelineState.CANCELLING
    self.client.cancel(self.futures)
    self._state = PipelineState.CANCELLED
    return self._state

  def metrics(self):
    # TODO(alxr): Collect and return metrics...
    raise NotImplementedError('collecting metrics will come later!')


class DaskRunner(BundleBasedDirectRunner):
  """Executes a pipeline on a Dask distributed client."""
  @staticmethod
  def to_dask_bag_visitor() -> PipelineVisitor:
    from dask import bag as db

    @dataclasses.dataclass
    class DaskBagVisitor(PipelineVisitor):
      bags: t.Dict[AppliedPTransform,
                   db.Bag] = dataclasses.field(default_factory=dict)

      def visit_transform(self, transform_node: AppliedPTransform) -> None:
        # Look up the dask translation for this transform; transforms
        # without a translation fall through to NoOp.
        op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
        op = op_class(transform_node)

        inputs = list(transform_node.inputs)
        if inputs:
          bag_inputs = []
          for input_value in inputs:
            # PBegin marks a pipeline source (e.g. Create), which has no
            # upstream bag to consume.
            if isinstance(input_value, pvalue.PBegin):
              bag_inputs.append(None)

            # Wire in the bag produced by the upstream transform, if any.
            prev_op = input_value.producer
            if prev_op in self.bags:
              bag_inputs.append(self.bags[prev_op])

          if len(bag_inputs) == 1:
            self.bags[transform_node] = op.apply(bag_inputs[0])
          else:
            self.bags[transform_node] = op.apply(bag_inputs)

        else:
          self.bags[transform_node] = op.apply(None)

    return DaskBagVisitor()

  @staticmethod
  def is_fnapi_compatible():
    return False

  def run_pipeline(self, pipeline, options):
    # TODO(alxr): Create interactive notebook support.
    if is_in_notebook():
      raise NotImplementedError('interactive support will come later!')

    try:
      import dask.distributed as ddist
    except ImportError:
      raise ImportError(
          'DaskRunner is not available. Please install apache_beam[dask].')

    dask_options = options.view_as(DaskOptions).get_all_options(
        drop_default=True)
    client = ddist.Client(**dask_options)

    # Swap Beam's primitive transforms for Dask-translatable overrides.
    pipeline.replace_all(dask_overrides())

    # Walk the pipeline graph, building a dask Bag per transform.
    dask_visitor = self.to_dask_bag_visitor()
    pipeline.visit(dask_visitor)

    # Submit every constructed bag for asynchronous execution.
    futures = client.compute(list(dask_visitor.bags.values()))
    return DaskRunnerResult(client, futures)
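
# A minimal usage sketch, assuming apache_beam[dask] is installed (the
# pipeline contents here are illustrative, not part of this module):
#
#   import apache_beam as beam
#
#   with beam.Pipeline(runner=DaskRunner()) as p:
#     (p
#      | beam.Create([1, 2, 3])
#      | beam.Map(lambda x: x * 2)
#      | beam.Map(print))
#
# With no --dask_client_address supplied, dask.distributed.Client() spins
# up an in-process local cluster; pass an address to run against a remote
# scheduler instead.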