#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Transform overrides used by the Dask runner.

Each override replaces a built-in Beam transform with a lightweight
placeholder whose ``expand`` only produces an output PCollection; the
placeholder keeps whatever state (values to create, source to read) is
needed later.

NOTE(review): translating these placeholders into actual Dask operations
presumably happens elsewhere in the runner -- it is not visible here.
"""
import dataclasses
import typing as t

import apache_beam as beam
from apache_beam import typehints
from apache_beam.io.iobase import SourceBase
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PTransformOverride
from apache_beam.runners.direct.direct_runner import _GroupAlsoByWindowDoFn
from apache_beam.transforms import ptransform
from apache_beam.transforms.window import GlobalWindows

K = t.TypeVar("K")
V = t.TypeVar("V")


@dataclasses.dataclass
class _Create(beam.PTransform):
  """Placeholder for ``beam.Create`` that carries the values to emit.

  Its windowing is always reported as the global window, regardless of the
  inputs it is applied to.
  """
  # Variadic tuple: ``beam.Create`` can hold any number of elements, so the
  # annotation must not pin the tuple to a single item.
  values: t.Tuple[t.Any, ...]

  def expand(self, input_or_inputs):
    # Only declares the output PCollection; no elements are produced here.
    return beam.pvalue.PCollection.from_(input_or_inputs)

  def get_windowing(self, inputs: t.Any) -> beam.Windowing:
    return beam.Windowing(GlobalWindows())


@typehints.with_input_types(K)
@typehints.with_output_types(K)
class _Reshuffle(beam.PTransform):
  """Placeholder for ``beam.Reshuffle``; element type passes through as-is."""
  def expand(self, input_or_inputs):
    return beam.pvalue.PCollection.from_(input_or_inputs)


@dataclasses.dataclass
class _Read(beam.PTransform):
  """Placeholder for ``beam.io.Read`` that keeps the source to read from."""
  source: SourceBase

  def expand(self, input_or_inputs):
    return beam.pvalue.PCollection.from_(input_or_inputs)


@typehints.with_input_types(t.Tuple[K, V])
@typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
class _GroupByKeyOnly(beam.PTransform):
  """Placeholder that groups ``(key, value)`` pairs into
  ``(key, iterable of values)``."""
  def expand(self, input_or_inputs):
    return beam.pvalue.PCollection.from_(input_or_inputs)

  def infer_output_type(self, input_type):
    """Derive ``KV[key, Iterable[value]]`` from the input's KV type."""
    key_type, value_type = typehints.trivial_inference.key_value_types(
        input_type)
    return typehints.KV[key_type, typehints.Iterable[value_type]]


@typehints.with_input_types(t.Tuple[K, t.Iterable[V]])
@typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
class _GroupAlsoByWindow(beam.ParDo):
  """Not used yet...

  Wraps the direct runner's ``_GroupAlsoByWindowDoFn`` for the given
  windowing strategy.
  """
  def __init__(self, windowing):
    super().__init__(_GroupAlsoByWindowDoFn(windowing))
    self.windowing = windowing

  def expand(self, input_or_inputs):
    return beam.pvalue.PCollection.from_(input_or_inputs)


@typehints.with_input_types(t.Tuple[K, V])
@typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
class _GroupByKey(beam.PTransform):
  """Replacement for ``beam.GroupByKey`` that expands to
  ``_GroupByKeyOnly``."""
  def expand(self, input_or_inputs):
    return input_or_inputs | "GroupByKey" >> _GroupByKeyOnly()


class _Flatten(beam.PTransform):
  """Placeholder for ``beam.Flatten``.

  The output is bounded only when every input PCollection is bounded.
  """
  def expand(self, input_or_inputs):
    is_bounded = all(pcoll.is_bounded for pcoll in input_or_inputs)
    return beam.pvalue.PCollection(self.pipeline, is_bounded=is_bounded)


def dask_overrides() -> t.List[PTransformOverride]:
  """Return the transform overrides applied for the Dask runner.

  Every override matches on the *exact* transform class (``__class__ ==``),
  so subclasses of ``Create``, ``Reshuffle``, ``Read``, ``GroupByKey`` and
  ``Flatten`` are intentionally left untouched.

  Returns:
    A new list of ``PTransformOverride`` instances, one per replaced
    transform, in the order Create, Reshuffle, Read, GroupByKey, Flatten.
  """
  class CreateOverride(PTransformOverride):
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      return applied_ptransform.transform.__class__ == beam.Create

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      # Carry the original elements over to the placeholder.
      return _Create(t.cast(beam.Create, applied_ptransform.transform).values)

  class ReshuffleOverride(PTransformOverride):
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      return applied_ptransform.transform.__class__ == beam.Reshuffle

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      return _Reshuffle()

  class ReadOverride(PTransformOverride):
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      return applied_ptransform.transform.__class__ == beam.io.Read

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      # Carry the original source over to the placeholder.
      return _Read(t.cast(beam.io.Read, applied_ptransform.transform).source)

  class GroupByKeyOverride(PTransformOverride):
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      return applied_ptransform.transform.__class__ == beam.GroupByKey

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      return _GroupByKey()

  class FlattenOverride(PTransformOverride):
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      return applied_ptransform.transform.__class__ == beam.Flatten

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      return _Flatten()

  return [
      CreateOverride(),
      ReshuffleOverride(),
      ReadOverride(),
      GroupByKeyOverride(),
      FlattenOverride(),
  ]