github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/dask/transform_evaluator.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Transform Beam PTransforms into Dask Bag operations.

A minimal set of operation substitutions to adapt Beam's PTransform model
to Dask Bag functions.

TODO(alxr): Translate ops from https://docs.dask.org/en/latest/bag-api.html.
"""
import abc
import dataclasses
import typing as t

import apache_beam
import dask.bag as db
from apache_beam.pipeline import AppliedPTransform
from apache_beam.runners.dask.overrides import _Create
from apache_beam.runners.dask.overrides import _Flatten
from apache_beam.runners.dask.overrides import _GroupByKeyOnly

# An op's input is a single bag, a sequence of bags (e.g. for Flatten), or
# nothing at all (for root transforms such as Create).
OpInput = t.Union[db.Bag, t.Sequence[db.Bag], None]


@dataclasses.dataclass
class DaskBagOp(abc.ABC):
  """Base class for translating one AppliedPTransform into Dask Bag calls."""
  applied: AppliedPTransform

  @property
  def transform(self):
    return self.applied.transform

  @abc.abstractmethod
  def apply(self, input_bag: OpInput) -> db.Bag:
    pass


class NoOp(DaskBagOp):
  """Passes the input through unchanged; used for transforms with no Dask equivalent."""
  def apply(self, input_bag: OpInput) -> db.Bag:
    return input_bag


class Create(DaskBagOp):
  """Materializes an in-memory sequence of values as a Dask Bag."""
  def apply(self, input_bag: OpInput) -> db.Bag:
    assert input_bag is None, 'Create expects no input!'
    original_transform = t.cast(_Create, self.transform)
    items = original_transform.values
    return db.from_sequence(items)


class ParDo(DaskBagOp):
  """Applies a DoFn to each element; flattens because process() may yield many outputs."""
  def apply(self, input_bag: db.Bag) -> db.Bag:
    transform = t.cast(apache_beam.ParDo, self.transform)
    return input_bag.map(
        transform.fn.process, *transform.args, **transform.kwargs).flatten()


class Map(DaskBagOp):
  """Applies a one-to-one function to each element."""
  def apply(self, input_bag: db.Bag) -> db.Bag:
    transform = t.cast(apache_beam.Map, self.transform)
    return input_bag.map(
        transform.fn.process, *transform.args, **transform.kwargs)


class GroupByKey(DaskBagOp):
  """Groups (key, value) pairs by key, collecting the values into a list per key."""
  def apply(self, input_bag: db.Bag) -> db.Bag:
    def key(item):
      return item[0]

    def value(item):
      k, v = item
      return k, [elm[1] for elm in v]

    return input_bag.groupby(key).map(value)


class Flatten(DaskBagOp):
  """Concatenates a sequence of bags into a single bag."""
  def apply(self, input_bag: OpInput) -> db.Bag:
    assert type(input_bag) is list, 'Must take a sequence of bags!'
    return db.concat(input_bag)


# Maps Beam transform classes to the DaskBagOp that evaluates them.
TRANSLATIONS = {
  _Create: Create,
  apache_beam.ParDo: ParDo,
  apache_beam.Map: Map,
  _GroupByKeyOnly: GroupByKey,
  _Flatten: Flatten,
}
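
For reference, below is a minimal standalone sketch (plain Dask, no Beam objects) of the Bag calls that the Create, ParDo, and GroupByKey ops above reduce to. The sample data and lambdas are illustrative only and are not part of the module.

# Illustrative sketch of the Dask Bag operations produced by the ops above:
# Create -> from_sequence, ParDo -> map(...).flatten(), GroupByKey -> groupby(...).map(...).
import dask.bag as db

# Create-style: materialize an in-memory sequence of (key, value) pairs.
bag = db.from_sequence([('a', 1), ('b', 2), ('a', 3)], npartitions=1)

# ParDo-style: each element maps to a list of zero or more outputs (like a
# DoFn's process() yielding values), which is then flattened into one bag.
doubled = bag.map(lambda kv: [(kv[0], kv[1] * 2)]).flatten()

# GroupByKey-style: group by the key, then strip the keys from the grouped items.
grouped = doubled.groupby(lambda kv: kv[0]).map(
    lambda kv: (kv[0], [v for _, v in kv[1]]))

print(sorted(grouped.compute()))  # e.g. [('a', [2, 6]), ('b', [4])]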