#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Helper transforms for the direct runner.

Provides :class:`LiftedCombinePerKey`, a CombinePerKey implementation that
pre-combines values within each bundle before the GroupByKey, together with
the two DoFns it is composed of.
"""

# pytype: skip-file

import collections
import itertools
import typing

import apache_beam as beam
from apache_beam import typehints
from apache_beam.internal.util import ArgumentPlaceholder
from apache_beam.transforms.combiners import _CurriedFn
from apache_beam.utils.windowed_value import WindowedValue


class LiftedCombinePerKey(beam.PTransform):
  """An implementation of CombinePerKey that does mapper-side pre-combining.

  Values are first partially combined per (key, window) inside each bundle
  (:class:`PartialGroupByKeyCombiningValues`). The compacted accumulators are
  then grouped by key, and the grouped accumulators are merged and turned
  into the final output (:class:`FinishCombine`).

  Args:
    combine_fn: The CombineFn to apply. It is passed to ``curry_combine_fn``
      together with ``args`` and ``kwargs``.
    args: Extra positional arguments curried into ``combine_fn``.
    kwargs: Extra keyword arguments curried into ``combine_fn``.

  Raises:
    NotImplementedError: If any extra argument, including arguments already
      bound into a ``_CurriedFn``, is an ``ArgumentPlaceholder`` (a deferred
      side input).
  """
  def __init__(self, combine_fn, args, kwargs):
    # Run the PTransform initializer so base-class state is set up explicitly
    # instead of relying on lazily created defaults.
    super().__init__()
    args_to_check = itertools.chain(args, kwargs.values())
    if isinstance(combine_fn, _CurriedFn):
      # Arguments already curried into the CombineFn must be checked as well.
      args_to_check = itertools.chain(
          args_to_check, combine_fn.args, combine_fn.kwargs.values())
    if any(isinstance(arg, ArgumentPlaceholder) for arg in args_to_check):
      # This isn't implemented in dataflow either...
      raise NotImplementedError('Deferred CombineFn side inputs.')
    self._combine_fn = beam.transforms.combiners.curry_combine_fn(
        combine_fn, args, kwargs)

  def expand(self, pcoll):
    # Pre-combine within each bundle, shuffle the accumulators by key, then
    # merge them and extract one output per key.
    return (
        pcoll
        | beam.ParDo(PartialGroupByKeyCombiningValues(self._combine_fn))
        | beam.GroupByKey()
        | beam.ParDo(FinishCombine(self._combine_fn)))


class PartialGroupByKeyCombiningValues(beam.DoFn):
  """Aggregates values into a per-key-window cache.

  As bundles are in-memory-sized, we don't bother flushing until the very end.
  """
  def __init__(self, combine_fn):
    self._combine_fn = combine_fn

  def setup(self):
    self._combine_fn.setup()

  def start_bundle(self):
    # One accumulator per (key, window), created on first access.
    self._cache = collections.defaultdict(self._combine_fn.create_accumulator)

  def process(self, element, window=beam.DoFn.WindowParam):
    """Folds a ``(key, value)`` element into its (key, window) accumulator."""
    k, vi = element
    self._cache[k, window] = self._combine_fn.add_input(
        self._cache[k, window], vi)

  def finish_bundle(self):
    """Emits one ``(key, compacted accumulator)`` per (key, window).

    Each output is timestamped at ``w.end`` and assigned to its window ``w``.
    """
    for (k, w), va in self._cache.items():
      # We compact the accumulator since a GBK (which necessitates encoding)
      # will follow.
      yield WindowedValue((k, self._combine_fn.compact(va)), w.end, (w, ))
    # Release this bundle's accumulators now that they have been emitted,
    # rather than keeping them referenced until the next start_bundle.
    self._cache.clear()

  def teardown(self):
    self._combine_fn.teardown()

  def default_type_hints(self):
    """Derives ``Tuple[K, ...]`` input/output hints from the CombineFn's hints.

    The CombineFn's first positional input type becomes the value type of the
    keyed input. Outputs are ``Tuple[K, Any]`` (keyed accumulators).
    """
    hints = self._combine_fn.get_type_hints()
    K = typehints.TypeVariable('K')
    if hints.input_types:
      args, kwargs = hints.input_types
      # Without a positional input hint the element type is unknown; fall
      # back to Any instead of failing with an IndexError.
      element_type = args[0] if args else typing.Any
      args = (typehints.Tuple[K, element_type], ) + args[1:]
      hints = hints.with_input_types(*args, **kwargs)
    else:
      hints = hints.with_input_types(typehints.Tuple[K, typing.Any])
    hints = hints.with_output_types(typehints.Tuple[K, typing.Any])
    return hints


class FinishCombine(beam.DoFn):
  """Merges partially combined results."""
  def __init__(self, combine_fn):
    self._combine_fn = combine_fn

  def setup(self):
    self._combine_fn.setup()

  def process(self, element):
    """Merges a key's accumulators and returns ``[(key, output)]``."""
    # element is (key, iterable of accumulators) as grouped by the GBK.
    k, vs = element
    return [(
        k,
        self._combine_fn.extract_output(
            self._combine_fn.merge_accumulators(vs)))]

  def teardown(self):
    self._combine_fn.teardown()

  def default_type_hints(self):
    """Keys the CombineFn's output hint: inputs ``Tuple[K, Any]``, outputs
    ``Tuple[K, <CombineFn output type>]`` when an output hint exists."""
    hints = self._combine_fn.get_type_hints()
    K = typehints.TypeVariable('K')
    hints = hints.with_input_types(typehints.Tuple[K, typing.Any])
    if hints.output_types:
      main_output_type = hints.simple_output_type('')
      hints = hints.with_output_types(typehints.Tuple[K, main_output_type])
    return hints