github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/dask/transform_evaluator.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Transform Beam PTransforms into Dask Bag operations.
    19  
     20  A minimal set of operation substitutions to adapt Beam's PTransform model
    21  to Dask Bag functions.
    22  
    23  TODO(alxr): Translate ops from https://docs.dask.org/en/latest/bag-api.html.
    24  """
    25  import abc
    26  import dataclasses
    27  import typing as t
    28  
    29  import apache_beam
    30  import dask.bag as db
    31  from apache_beam.pipeline import AppliedPTransform
    32  from apache_beam.runners.dask.overrides import _Create
    33  from apache_beam.runners.dask.overrides import _Flatten
    34  from apache_beam.runners.dask.overrides import _GroupByKeyOnly
    35  
    36  OpInput = t.Union[db.Bag, t.Sequence[db.Bag], None]
    37  
    38  
    39  @dataclasses.dataclass
    40  class DaskBagOp(abc.ABC):
    41    applied: AppliedPTransform
    42  
    43    @property
    44    def transform(self):
    45      return self.applied.transform
    46  
    47    @abc.abstractmethod
    48    def apply(self, input_bag: OpInput) -> db.Bag:
    49      pass
    50  
    51  
    52  class NoOp(DaskBagOp):
    53    def apply(self, input_bag: OpInput) -> db.Bag:
    54      return input_bag
    55  
    56  
    57  class Create(DaskBagOp):
    58    def apply(self, input_bag: OpInput) -> db.Bag:
    59      assert input_bag is None, 'Create expects no input!'
    60      original_transform = t.cast(_Create, self.transform)
    61      items = original_transform.values
    62      return db.from_sequence(items)
    63  
    64  
    65  class ParDo(DaskBagOp):
    66    def apply(self, input_bag: db.Bag) -> db.Bag:
    67      transform = t.cast(apache_beam.ParDo, self.transform)
    68      return input_bag.map(
    69          transform.fn.process, *transform.args, **transform.kwargs).flatten()
    70  
    71  
    72  class Map(DaskBagOp):
    73    def apply(self, input_bag: db.Bag) -> db.Bag:
    74      transform = t.cast(apache_beam.Map, self.transform)
    75      return input_bag.map(
    76          transform.fn.process, *transform.args, **transform.kwargs)
    77  
    78  
    79  class GroupByKey(DaskBagOp):
    80    def apply(self, input_bag: db.Bag) -> db.Bag:
    81      def key(item):
    82        return item[0]
    83  
    84      def value(item):
    85        k, v = item
    86        return k, [elm[1] for elm in v]
    87  
    88      return input_bag.groupby(key).map(value)
    89  
    90  
    91  class Flatten(DaskBagOp):
    92    def apply(self, input_bag: OpInput) -> db.Bag:
    93      assert type(input_bag) is list, 'Must take a sequence of bags!'
    94      return db.concat(input_bag)
    95  
    96  
    97  TRANSLATIONS = {
    98      _Create: Create,
    99      apache_beam.ParDo: ParDo,
   100      apache_beam.Map: Map,
   101      _GroupByKeyOnly: GroupByKey,
   102      _Flatten: Flatten,
   103  }