github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/direct/helper_transforms.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  import collections
    21  import itertools
    22  import typing
    23  
    24  import apache_beam as beam
    25  from apache_beam import typehints
    26  from apache_beam.internal.util import ArgumentPlaceholder
    27  from apache_beam.transforms.combiners import _CurriedFn
    28  from apache_beam.utils.windowed_value import WindowedValue
    29  
    30  
    31  class LiftedCombinePerKey(beam.PTransform):
    32    """An implementation of CombinePerKey that does mapper-side pre-combining.
    33    """
    34    def __init__(self, combine_fn, args, kwargs):
    35      args_to_check = itertools.chain(args, kwargs.values())
    36      if isinstance(combine_fn, _CurriedFn):
    37        args_to_check = itertools.chain(
    38            args_to_check, combine_fn.args, combine_fn.kwargs.values())
    39      if any(isinstance(arg, ArgumentPlaceholder) for arg in args_to_check):
    40        # This isn't implemented in dataflow either...
    41        raise NotImplementedError('Deferred CombineFn side inputs.')
    42      self._combine_fn = beam.transforms.combiners.curry_combine_fn(
    43          combine_fn, args, kwargs)
    44  
    45    def expand(self, pcoll):
    46      return (
    47          pcoll
    48          | beam.ParDo(PartialGroupByKeyCombiningValues(self._combine_fn))
    49          | beam.GroupByKey()
    50          | beam.ParDo(FinishCombine(self._combine_fn)))
    51  
    52  
    53  class PartialGroupByKeyCombiningValues(beam.DoFn):
    54    """Aggregates values into a per-key-window cache.
    55  
    56    As bundles are in-memory-sized, we don't bother flushing until the very end.
    57    """
    58    def __init__(self, combine_fn):
    59      self._combine_fn = combine_fn
    60  
    61    def setup(self):
    62      self._combine_fn.setup()
    63  
    64    def start_bundle(self):
    65      self._cache = collections.defaultdict(self._combine_fn.create_accumulator)
    66  
    67    def process(self, element, window=beam.DoFn.WindowParam):
    68      k, vi = element
    69      self._cache[k, window] = self._combine_fn.add_input(
    70          self._cache[k, window], vi)
    71  
    72    def finish_bundle(self):
    73      for (k, w), va in self._cache.items():
    74        # We compact the accumulator since a GBK (which necessitates encoding)
    75        # will follow.
    76        yield WindowedValue((k, self._combine_fn.compact(va)), w.end, (w, ))
    77  
    78    def teardown(self):
    79      self._combine_fn.teardown()
    80  
    81    def default_type_hints(self):
    82      hints = self._combine_fn.get_type_hints()
    83      K = typehints.TypeVariable('K')
    84      if hints.input_types:
    85        args, kwargs = hints.input_types
    86        args = (typehints.Tuple[K, args[0]], ) + args[1:]
    87        hints = hints.with_input_types(*args, **kwargs)
    88      else:
    89        hints = hints.with_input_types(typehints.Tuple[K, typing.Any])
    90      hints = hints.with_output_types(typehints.Tuple[K, typing.Any])
    91      return hints
    92  
    93  
    94  class FinishCombine(beam.DoFn):
    95    """Merges partially combined results.
    96    """
    97    def __init__(self, combine_fn):
    98      self._combine_fn = combine_fn
    99  
   100    def setup(self):
   101      self._combine_fn.setup()
   102  
   103    def process(self, element):
   104      k, vs = element
   105      return [(
   106          k,
   107          self._combine_fn.extract_output(
   108              self._combine_fn.merge_accumulators(vs)))]
   109  
   110    def teardown(self):
   111      self._combine_fn.teardown()
   112  
   113    def default_type_hints(self):
   114      hints = self._combine_fn.get_type_hints()
   115      K = typehints.TypeVariable('K')
   116      hints = hints.with_input_types(typehints.Tuple[K, typing.Any])
   117      if hints.output_types:
   118        main_output_type = hints.simple_output_type('')
   119        hints = hints.with_output_types(typehints.Tuple[K, main_output_type])
   120      return hints