github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/dask/overrides.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  import dataclasses
    18  import typing as t
    19  
    20  import apache_beam as beam
    21  from apache_beam import typehints
    22  from apache_beam.io.iobase import SourceBase
    23  from apache_beam.pipeline import AppliedPTransform
    24  from apache_beam.pipeline import PTransformOverride
    25  from apache_beam.runners.direct.direct_runner import _GroupAlsoByWindowDoFn
    26  from apache_beam.transforms import ptransform
    27  from apache_beam.transforms.window import GlobalWindows
    28  
    29  K = t.TypeVar("K")
    30  V = t.TypeVar("V")
    31  
    32  
    33  @dataclasses.dataclass
    34  class _Create(beam.PTransform):
    35    values: t.Tuple[t.Any]
    36  
    37    def expand(self, input_or_inputs):
    38      return beam.pvalue.PCollection.from_(input_or_inputs)
    39  
    40    def get_windowing(self, inputs: t.Any) -> beam.Windowing:
    41      return beam.Windowing(GlobalWindows())
    42  
    43  
    44  @typehints.with_input_types(K)
    45  @typehints.with_output_types(K)
    46  class _Reshuffle(beam.PTransform):
    47    def expand(self, input_or_inputs):
    48      return beam.pvalue.PCollection.from_(input_or_inputs)
    49  
    50  
    51  @dataclasses.dataclass
    52  class _Read(beam.PTransform):
    53    source: SourceBase
    54  
    55    def expand(self, input_or_inputs):
    56      return beam.pvalue.PCollection.from_(input_or_inputs)
    57  
    58  
    59  @typehints.with_input_types(t.Tuple[K, V])
    60  @typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
    61  class _GroupByKeyOnly(beam.PTransform):
    62    def expand(self, input_or_inputs):
    63      return beam.pvalue.PCollection.from_(input_or_inputs)
    64  
    65    def infer_output_type(self, input_type):
    66  
    67      key_type, value_type = typehints.trivial_inference.key_value_types(
    68        input_type
    69      )
    70      return typehints.KV[key_type, typehints.Iterable[value_type]]
    71  
    72  
    73  @typehints.with_input_types(t.Tuple[K, t.Iterable[V]])
    74  @typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
    75  class _GroupAlsoByWindow(beam.ParDo):
    76    """Not used yet..."""
    77    def __init__(self, windowing):
    78      super().__init__(_GroupAlsoByWindowDoFn(windowing))
    79      self.windowing = windowing
    80  
    81    def expand(self, input_or_inputs):
    82      return beam.pvalue.PCollection.from_(input_or_inputs)
    83  
    84  
    85  @typehints.with_input_types(t.Tuple[K, V])
    86  @typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
    87  class _GroupByKey(beam.PTransform):
    88    def expand(self, input_or_inputs):
    89      return input_or_inputs | "GroupByKey" >> _GroupByKeyOnly()
    90  
    91  
    92  class _Flatten(beam.PTransform):
    93    def expand(self, input_or_inputs):
    94      is_bounded = all(pcoll.is_bounded for pcoll in input_or_inputs)
    95      return beam.pvalue.PCollection(self.pipeline, is_bounded=is_bounded)
    96  
    97  
    98  def dask_overrides() -> t.List[PTransformOverride]:
    99    class CreateOverride(PTransformOverride):
   100      def matches(self, applied_ptransform: AppliedPTransform) -> bool:
   101        return applied_ptransform.transform.__class__ == beam.Create
   102  
   103      def get_replacement_transform_for_applied_ptransform(
   104          self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
   105        return _Create(t.cast(beam.Create, applied_ptransform.transform).values)
   106  
   107    class ReshuffleOverride(PTransformOverride):
   108      def matches(self, applied_ptransform: AppliedPTransform) -> bool:
   109        return applied_ptransform.transform.__class__ == beam.Reshuffle
   110  
   111      def get_replacement_transform_for_applied_ptransform(
   112          self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
   113        return _Reshuffle()
   114  
   115    class ReadOverride(PTransformOverride):
   116      def matches(self, applied_ptransform: AppliedPTransform) -> bool:
   117        return applied_ptransform.transform.__class__ == beam.io.Read
   118  
   119      def get_replacement_transform_for_applied_ptransform(
   120          self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
   121        return _Read(t.cast(beam.io.Read, applied_ptransform.transform).source)
   122  
   123    class GroupByKeyOverride(PTransformOverride):
   124      def matches(self, applied_ptransform: AppliedPTransform) -> bool:
   125        return applied_ptransform.transform.__class__ == beam.GroupByKey
   126  
   127      def get_replacement_transform_for_applied_ptransform(
   128          self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
   129        return _GroupByKey()
   130  
   131    class FlattenOverride(PTransformOverride):
   132      def matches(self, applied_ptransform: AppliedPTransform) -> bool:
   133        return applied_ptransform.transform.__class__ == beam.Flatten
   134  
   135      def get_replacement_transform_for_applied_ptransform(
   136          self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
   137        return _Flatten()
   138  
   139    return [
   140        CreateOverride(),
   141        ReshuffleOverride(),
   142        ReadOverride(),
   143        GroupByKeyOverride(),
   144        FlattenOverride(),
   145    ]