github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/combiners.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A library of basic combiner PTransform subclasses."""

# pytype: skip-file

import copy
import heapq
import itertools
import operator
import random
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Set
from typing import Tuple
from typing import TypeVar
from typing import Union

import numpy as np

from apache_beam import typehints
from apache_beam.transforms import core
from apache_beam.transforms import cy_combiners
from apache_beam.transforms import ptransform
from apache_beam.transforms import window
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.typehints import with_input_types
from apache_beam.typehints import with_output_types
from apache_beam.utils.timestamp import Duration
from apache_beam.utils.timestamp import Timestamp

__all__ = [
    'Count',
    'Mean',
    'Sample',
    'Top',
    'ToDict',
    'ToList',
    'ToSet',
    'Latest',
    'CountCombineFn',
    'MeanCombineFn',
    'SampleCombineFn',
    'TopCombineFn',
    'ToDictCombineFn',
    'ToListCombineFn',
    'ToSetCombineFn',
    'LatestCombineFn',
]

# Type variables
T = TypeVar('T')
K = TypeVar('K')
V = TypeVar('V')
TimestampType = Union[int, float, Timestamp, Duration]


class CombinerWithoutDefaults(ptransform.PTransform):
  """Base class that adds without_defaults support to the built-in combiners."""
  def __init__(self, has_defaults=True):
    super().__init__()
    self.has_defaults = has_defaults

  def with_defaults(self, has_defaults=True):
    new = copy.copy(self)
    new.has_defaults = has_defaults
    return new

  def without_defaults(self):
    return self.with_defaults(False)

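# Usage sketch (assumes `import apache_beam as beam` and a windowed input
# PCollection named `windowed_pcoll`): subclasses such as Count.Globally
# produce a default output for an empty input, which cannot be emitted under
# non-global windowing, so callers chain without_defaults():
#
#   counted = (
#       windowed_pcoll
#       | beam.combiners.Count.Globally().without_defaults())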

class Mean(object):
  """Combiners for computing arithmetic means of elements."""
  class Globally(CombinerWithoutDefaults):
    """combiners.Mean.Globally computes the arithmetic mean of the elements."""
    def expand(self, pcoll):
      if self.has_defaults:
        return pcoll | core.CombineGlobally(MeanCombineFn())
      else:
        return pcoll | core.CombineGlobally(MeanCombineFn()).without_defaults()

  class PerKey(ptransform.PTransform):
    """combiners.Mean.PerKey finds the means of the values for each key."""
    def expand(self, pcoll):
      return pcoll | core.CombinePerKey(MeanCombineFn())

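# Usage sketch for Mean (assumes `import apache_beam as beam`); the element
# values are illustrative:
#
#   with beam.Pipeline() as p:
#     global_mean = (
#         p | beam.Create([1, 2, 3, 4])
#         | beam.combiners.Mean.Globally())  # -> [2.5]
#     per_key_mean = (
#         p | 'kv' >> beam.Create([('a', 1), ('a', 3), ('b', 10)])
#         | beam.combiners.Mean.PerKey())  # -> [('a', 2.0), ('b', 10.0)]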

# TODO(laolu): This type signature is overly restrictive. This should be
# more general.
@with_input_types(Union[float, int, np.int64, np.float64])
@with_output_types(float)
class MeanCombineFn(core.CombineFn):
  """CombineFn for computing an arithmetic mean."""
  def create_accumulator(self):
    return (0, 0)

  def add_input(self, sum_count, element):
    (sum_, count) = sum_count
    return sum_ + element, count + 1

  def merge_accumulators(self, accumulators):
    sums, counts = zip(*accumulators)
    return sum(sums), sum(counts)

  def extract_output(self, sum_count):
    (sum_, count) = sum_count
    if count == 0:
      return float('NaN')
    return sum_ / float(count)

  def for_input_type(self, input_type):
    if input_type is int:
      return cy_combiners.MeanInt64Fn()
    elif input_type is float:
      return cy_combiners.MeanFloatFn()
    return self


class Count(object):
  """Combiners for counting elements."""
  @with_input_types(T)
  @with_output_types(int)
  class Globally(CombinerWithoutDefaults):
    """combiners.Count.Globally counts the total number of elements."""
    def expand(self, pcoll):
      if self.has_defaults:
        return pcoll | core.CombineGlobally(CountCombineFn())
      else:
        return pcoll | core.CombineGlobally(CountCombineFn()).without_defaults()

  @with_input_types(Tuple[K, V])
  @with_output_types(Tuple[K, int])
  class PerKey(ptransform.PTransform):
    """combiners.Count.PerKey counts how many elements each unique key has."""
    def expand(self, pcoll):
      return pcoll | core.CombinePerKey(CountCombineFn())

  @with_input_types(T)
  @with_output_types(Tuple[T, int])
  class PerElement(ptransform.PTransform):
    """combiners.Count.PerElement counts how many times each element occurs."""
    def expand(self, pcoll):
      paired_with_void_type = typehints.Tuple[pcoll.element_type, Any]
      output_type = typehints.KV[pcoll.element_type, int]
      return (
          pcoll
          | (
              '%s:PairWithVoid' % self.label >> core.Map(
                  lambda x: (x, None)).with_output_types(paired_with_void_type))
          | core.CombinePerKey(CountCombineFn()).with_output_types(output_type))

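# Usage sketch for Count (assumes `import apache_beam as beam`); output
# ordering is not guaranteed:
#
#   with beam.Pipeline() as p:
#     total = (
#         p | beam.Create(['a', 'b', 'a'])
#         | beam.combiners.Count.Globally())  # -> [3]
#     per_element = (
#         p | 'chars' >> beam.Create(['a', 'b', 'a'])
#         | beam.combiners.Count.PerElement())  # -> [('a', 2), ('b', 1)]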

@with_input_types(Any)
@with_output_types(int)
class CountCombineFn(core.CombineFn):
  """CombineFn for computing PCollection size."""
  def create_accumulator(self):
    return 0

  def add_input(self, accumulator, element):
    return accumulator + 1

  def add_inputs(self, accumulator, elements):
    return accumulator + len(list(elements))

  def merge_accumulators(self, accumulators):
    return sum(accumulators)

  def extract_output(self, accumulator):
    return accumulator


class Top(object):
  """Combiners for obtaining extremal elements."""

  # pylint: disable=no-self-argument
  @with_input_types(T)
  @with_output_types(List[T])
  class Of(CombinerWithoutDefaults):
    """Returns the n greatest elements in the PCollection.

    This transform retrieves the n greatest elements of the PCollection to
    which it is applied, where "greatest" is determined by the comparison
    key supplied as the `key` argument; passing `reverse=True` yields the n
    least elements instead.
    """
    def __init__(self, n, key=None, reverse=False):
      """Creates a global Top operation.

      The arguments 'key' and 'reverse' may be passed as keyword arguments,
      and have the same meaning as for Python's sort functions.

      Args:
        n: number of elements to extract from pcoll.
        key: (optional) a mapping of elements to a comparable key, similar to
            the key argument of Python's sorting methods.
        reverse: (optional) whether to order things smallest to largest,
            rather than largest to smallest.
      """
      super().__init__()
      self._n = n
      self._key = key
      self._reverse = reverse

    def default_label(self):
      return 'Top(%d)' % self._n

    def expand(self, pcoll):
      if pcoll.windowing.is_default():
        # This is a more efficient global algorithm.
        top_per_bundle = pcoll | core.ParDo(
            _TopPerBundle(self._n, self._key, self._reverse))
        # Since pcoll may be empty, top_per_bundle may be empty too, so
        # inject at least one empty accumulator so that downstream is
        # guaranteed to produce non-empty output.
        empty_bundle = (
            pcoll.pipeline | core.Create([(None, [])]).with_output_types(
                top_per_bundle.element_type))
        return ((top_per_bundle, empty_bundle) | core.Flatten()
                | core.GroupByKey()
                | core.ParDo(
                    _MergeTopPerBundle(self._n, self._key, self._reverse)))
      else:
        if self.has_defaults:
          return pcoll | core.CombineGlobally(
              TopCombineFn(self._n, self._key, self._reverse))
        else:
          return pcoll | core.CombineGlobally(
              TopCombineFn(self._n, self._key,
                           self._reverse)).without_defaults()

  @with_input_types(Tuple[K, V])
  @with_output_types(Tuple[K, List[V]])
  class PerKey(ptransform.PTransform):
    """Identifies the N greatest elements associated with each key.

    This transform produces a PCollection mapping each unique key in the
    input PCollection to the n greatest elements with which it is associated,
    where "greatest" is determined by the comparison key supplied as the
    `key` argument; passing `reverse=True` yields the n least elements
    instead.
    """
    def __init__(self, n, key=None, reverse=False):
      """Creates a per-key Top operation.

      The arguments 'key' and 'reverse' may be passed as keyword arguments,
      and have the same meaning as for Python's sort functions.

      Args:
        n: number of elements to extract from pcoll.
        key: (optional) a mapping of elements to a comparable key, similar to
            the key argument of Python's sorting methods.
        reverse: (optional) whether to order things smallest to largest,
            rather than largest to smallest.
      """
      self._n = n
      self._key = key
      self._reverse = reverse

    def default_label(self):
      return 'TopPerKey(%d)' % self._n

    def expand(self, pcoll):
      """Expands the transform.

      Args:
        pcoll: PCollection to process

      Returns:
        the PCollection containing the result.

      Raises:
        TypeCheckError: If the output type of the input PCollection is not
          compatible with Tuple[A, B].
      """
      return pcoll | core.CombinePerKey(
          TopCombineFn(self._n, self._key, self._reverse))

  @staticmethod
  @ptransform.ptransform_fn
  def Largest(pcoll, n, has_defaults=True, key=None):
    """Obtain a list of the greatest N elements in a PCollection."""
    if has_defaults:
      return pcoll | Top.Of(n, key)
    else:
      return pcoll | Top.Of(n, key).without_defaults()

  @staticmethod
  @ptransform.ptransform_fn
  def Smallest(pcoll, n, has_defaults=True, key=None):
    """Obtain a list of the least N elements in a PCollection."""
    if has_defaults:
      return pcoll | Top.Of(n, key, reverse=True)
    else:
      return pcoll | Top.Of(n, key, reverse=True).without_defaults()

  @staticmethod
  @ptransform.ptransform_fn
  def LargestPerKey(pcoll, n, key=None):
    """Identifies the N greatest elements associated with each key."""
    return pcoll | Top.PerKey(n, key)

  @staticmethod
  @ptransform.ptransform_fn
  def SmallestPerKey(pcoll, n, *, key=None, reverse=None):
    """Identifies the N least elements associated with each key."""
    return pcoll | Top.PerKey(n, key, reverse=True)

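# Usage sketch for Top (assumes `import apache_beam as beam`); the key
# function and element values are illustrative:
#
#   with beam.Pipeline() as p:
#     longest = (
#         p | beam.Create(['a', 'bb', 'cccc', 'ddd'])
#         | beam.combiners.Top.Of(2, key=len))  # -> [['cccc', 'ddd']]
#     smallest = (
#         p | 'nums' >> beam.Create([5, 1, 9, 3])
#         | beam.combiners.Top.Smallest(2))  # -> [[1, 3]]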

@with_input_types(T)
@with_output_types(Tuple[None, List[T]])
class _TopPerBundle(core.DoFn):
  def __init__(self, n, key, reverse):
    self._n = n
    self._compare = operator.gt if reverse else None
    self._key = key

  def start_bundle(self):
    self._heap = []

  def process(self, element):
    if self._compare or self._key:
      element = cy_combiners.ComparableValue(element, self._compare, self._key)
    if len(self._heap) < self._n:
      heapq.heappush(self._heap, element)
    else:
      heapq.heappushpop(self._heap, element)

  def finish_bundle(self):
    # Though sorting here results in more total work, this allows us to
    # skip most elements in the reducer.
    # Essentially, given s map bundles, we are trading about O(sn) compares in
    # the (single) reducer for O(sn log n) compares across all mappers.
    self._heap.sort()

    # Unwrap to avoid serialization via pickle.
    if self._compare or self._key:
      yield window.GlobalWindows.windowed_value(
          (None, [wrapper.value for wrapper in self._heap]))
    else:
      yield window.GlobalWindows.windowed_value((None, self._heap))


@with_input_types(Tuple[None, Iterable[List[T]]])
@with_output_types(List[T])
class _MergeTopPerBundle(core.DoFn):
  def __init__(self, n, key, reverse):
    self._n = n
    self._compare = operator.gt if reverse else None
    self._key = key

  def process(self, key_and_bundles):
    _, bundles = key_and_bundles

    def push(hp, e):
      if len(hp) < self._n:
        heapq.heappush(hp, e)
        return False
      elif e < hp[0]:
        # Because _TopPerBundle returns sorted lists, all other elements
        # will also be smaller.
        return True
      else:
        heapq.heappushpop(hp, e)
        return False

    if self._compare or self._key:
      heapc = []  # type: List[cy_combiners.ComparableValue]
      for bundle in bundles:
        if not heapc:
          heapc = [
              cy_combiners.ComparableValue(element, self._compare, self._key)
              for element in bundle
          ]
          continue
        # TODO(https://github.com/apache/beam/issues/21205): Remove this
        # workaround once legacy dataflow correctly handles coders with
        # combiner packing and/or is deprecated.
        if not isinstance(bundle, list):
          bundle = list(bundle)
        for element in reversed(bundle):
          if push(heapc,
                  cy_combiners.ComparableValue(element,
                                               self._compare,
                                               self._key)):
            break
      heapc.sort()
      yield [wrapper.value for wrapper in reversed(heapc)]

    else:
      heap = []
      for bundle in bundles:
        # TODO(https://github.com/apache/beam/issues/21205): Remove this
        # workaround once legacy dataflow correctly handles coders with
        # combiner packing and/or is deprecated.
        if not isinstance(bundle, list):
          bundle = list(bundle)
        if not heap:
          heap = bundle
          continue
        for element in reversed(bundle):
          if push(heap, element):
            break
      heap.sort()
      yield heap[::-1]


@with_input_types(T)
@with_output_types(List[T])
class TopCombineFn(core.CombineFn):
  """CombineFn doing the combining for all of the Top transforms.

  This CombineFn ranks elements by the `key` mapping, optionally inverting
  the ordering via `reverse`.

  Args:
    n: number of elements to extract.
    key: (optional) a mapping of elements to a comparable key, similar to
        the key argument of Python's sorting methods.
    reverse: (optional) whether to order things smallest to largest,
        rather than largest to smallest.
  """
  def __init__(self, n, key=None, reverse=False):
    self._n = n
    self._compare = operator.gt if reverse else operator.lt
    self._key = key

  def _hydrated_heap(self, heap):
    if heap:
      first = heap[0]
      if isinstance(first, cy_combiners.ComparableValue):
        if first.requires_hydration:
          for comparable in heap:
            assert comparable.requires_hydration
            comparable.hydrate(self._compare, self._key)
            assert not comparable.requires_hydration
          return heap
        else:
          return heap
      else:
        return [
            cy_combiners.ComparableValue(element, self._compare, self._key)
            for element in heap
        ]
    else:
      return heap

  def display_data(self):
    return {
        'n': self._n,
        'compare': DisplayDataItem(
            self._compare.__name__ if hasattr(self._compare, '__name__') else
            self._compare.__class__.__name__).drop_if_none()
    }

  # The accumulator type is a tuple
  # (bool, Union[List[T], List[ComparableValue[T]]])
  # where the boolean indicates whether the second slot contains a List of T
  # (False) or List of ComparableValue[T] (True). In either case, the List
  # maintains heap invariance. When the contents of the List are
  # ComparableValue[T] they either all 'requires_hydration' or none do.
  # This accumulator representation allows us to minimize the data encoding
  # overheads. Creation of ComparableValues is elided for performance reasons
  # when there is no need for complicated comparison functions.
  def create_accumulator(self, *args, **kwargs):
    return (False, [])

  def add_input(self, accumulator, element, *args, **kwargs):
    # Caching to avoid paying the price of variadic expansion of args / kwargs
    # when it's not needed (for the 'if' case below).
    holds_comparables, heap = accumulator
    if self._compare is not operator.lt or self._key:
      heap = self._hydrated_heap(heap)
      holds_comparables = True
    else:
      assert not holds_comparables

    comparable = (
        cy_combiners.ComparableValue(element, self._compare, self._key)
        if holds_comparables else element)

    if len(heap) < self._n:
      heapq.heappush(heap, comparable)
    else:
      heapq.heappushpop(heap, comparable)
    return (holds_comparables, heap)

  def merge_accumulators(self, accumulators, *args, **kwargs):
    result_heap = None
    holds_comparables = None
    for accumulator in accumulators:
      holds_comparables, heap = accumulator
      if self._compare is not operator.lt or self._key:
        heap = self._hydrated_heap(heap)
        holds_comparables = True
      else:
        assert not holds_comparables

      if result_heap is None:
        result_heap = heap
      else:
        for comparable in heap:
          _, result_heap = self.add_input(
              (holds_comparables, result_heap),
              comparable.value if holds_comparables else comparable)

    assert result_heap is not None and holds_comparables is not None
    return (holds_comparables, result_heap)

  def compact(self, accumulator, *args, **kwargs):
    holds_comparables, heap = accumulator
    # Unwrap to avoid serialization via pickle.
    if holds_comparables:
      return (False, [comparable.value for comparable in heap])
    else:
      return accumulator

  def extract_output(self, accumulator, *args, **kwargs):
    holds_comparables, heap = accumulator
    if self._compare is not operator.lt or self._key:
      if not holds_comparables:
        heap = self._hydrated_heap(heap)
        holds_comparables = True
    else:
      assert not holds_comparables

    assert len(heap) <= self._n
    heap.sort(reverse=True)
    return [
        comparable.value if holds_comparables else comparable
        for comparable in heap
    ]


class Largest(TopCombineFn):
  def default_label(self):
    return 'Largest(%s)' % self._n


class Smallest(TopCombineFn):
  def __init__(self, n):
    super().__init__(n, reverse=True)

  def default_label(self):
    return 'Smallest(%s)' % self._n


class Sample(object):
  """Combiners for sampling n elements without replacement."""

  # pylint: disable=no-self-argument

  @with_input_types(T)
  @with_output_types(List[T])
  class FixedSizeGlobally(CombinerWithoutDefaults):
    """Sample n elements from the input PCollection without replacement."""
    def __init__(self, n):
      super().__init__()
      self._n = n

    def expand(self, pcoll):
      if self.has_defaults:
        return pcoll | core.CombineGlobally(SampleCombineFn(self._n))
      else:
        return pcoll | core.CombineGlobally(SampleCombineFn(
            self._n)).without_defaults()

    def display_data(self):
      return {'n': self._n}

    def default_label(self):
      return 'FixedSizeGlobally(%d)' % self._n

  @with_input_types(Tuple[K, V])
  @with_output_types(Tuple[K, List[V]])
  class FixedSizePerKey(ptransform.PTransform):
    """Sample n elements associated with each key without replacement."""
    def __init__(self, n):
      self._n = n

    def expand(self, pcoll):
      return pcoll | core.CombinePerKey(SampleCombineFn(self._n))

    def display_data(self):
      return {'n': self._n}

    def default_label(self):
      return 'FixedSizePerKey(%d)' % self._n

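# Usage sketch for Sample (assumes `import apache_beam as beam`). The output
# is a single list of n elements chosen uniformly without replacement, so the
# exact contents vary from run to run:
#
#   with beam.Pipeline() as p:
#     sampled = (
#         p | beam.Create(range(100))
#         | beam.combiners.Sample.FixedSizeGlobally(3))  # e.g. [[17, 42, 88]]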

@with_input_types(T)
@with_output_types(List[T])
class SampleCombineFn(core.CombineFn):
  """CombineFn for all Sample transforms."""
  def __init__(self, n):
    super().__init__()
    # Most of this combiner's work is done by a TopCombineFn. We could just
    # subclass TopCombineFn to make this class, but since sampling is not
    # really a kind of Top operation, we use a TopCombineFn instance as a
    # helper instead.
    self._top_combiner = TopCombineFn(n)

  def setup(self):
    self._top_combiner.setup()

  def create_accumulator(self):
    return self._top_combiner.create_accumulator()

  def add_input(self, heap, element):
    # Before passing elements to the Top combiner, we pair them with random
    # numbers. The elements with the n largest random number "keys" will be
    # selected for the output.
    return self._top_combiner.add_input(heap, (random.random(), element))

  def merge_accumulators(self, heaps):
    return self._top_combiner.merge_accumulators(heaps)

  def compact(self, heap):
    return self._top_combiner.compact(heap)

  def extract_output(self, heap):
    # Here we strip off the random number keys we added in add_input.
    return [e for _, e in self._top_combiner.extract_output(heap)]

  def teardown(self):
    self._top_combiner.teardown()


class _TupleCombineFnBase(core.CombineFn):
  def __init__(self, *combiners, merge_accumulators_batch_size=None):
    self._combiners = [core.CombineFn.maybe_from_callable(c) for c in combiners]
    self._named_combiners = combiners
    # If the `merge_accumulators_batch_size` value is not specified, we choose
    # a bounded default that is inversely proportional to the number of
    # accumulators in merged tuples.
    num_combiners = max(1, len(combiners))
    self._merge_accumulators_batch_size = (
        merge_accumulators_batch_size or max(10, 1000 // num_combiners))

  def display_data(self):
    combiners = [
        c.__name__ if hasattr(c, '__name__') else c.__class__.__name__
        for c in self._named_combiners
    ]
    return {
        'combiners': str(combiners),
        'merge_accumulators_batch_size': self._merge_accumulators_batch_size
    }

  def setup(self, *args, **kwargs):
    for c in self._combiners:
      c.setup(*args, **kwargs)

  def create_accumulator(self, *args, **kwargs):
    return [c.create_accumulator(*args, **kwargs) for c in self._combiners]

  def merge_accumulators(self, accumulators, *args, **kwargs):
    # Make sure that `accumulators` is an iterator (so that the position is
    # remembered).
    accumulators = iter(accumulators)
    result = next(accumulators)
    while True:
      # Load accumulators into memory and merge in batches to decrease peak
      # memory usage.
      accumulators_batch = [result] + list(
          itertools.islice(accumulators, self._merge_accumulators_batch_size))
      if len(accumulators_batch) == 1:
        break
      result = [
          c.merge_accumulators(a, *args, **kwargs) for c,
          a in zip(self._combiners, zip(*accumulators_batch))
      ]
    return result

  def compact(self, accumulator, *args, **kwargs):
    return [
        c.compact(a, *args, **kwargs) for c,
        a in zip(self._combiners, accumulator)
    ]

  def extract_output(self, accumulator, *args, **kwargs):
    return tuple(
        c.extract_output(a, *args, **kwargs) for c,
        a in zip(self._combiners, accumulator))

  def teardown(self, *args, **kwargs):
    for c in reversed(self._combiners):
      c.teardown(*args, **kwargs)


class TupleCombineFn(_TupleCombineFnBase):
  """A combiner for combining tuples via a tuple of combiners.

  Takes as input a tuple of N CombineFns and combines N-tuples by
  combining the k-th element of each tuple with the k-th CombineFn,
  outputting a new N-tuple of combined values.
  """
  def add_input(self, accumulator, element, *args, **kwargs):
    return [
        c.add_input(a, e, *args, **kwargs) for c,
        a,
        e in zip(self._combiners, accumulator, element)
    ]

  def with_common_input(self):
    return SingleInputTupleCombineFn(*self._combiners)

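# Usage sketch for TupleCombineFn (assumes `import apache_beam as beam`).
# Each position of the value tuple is combined by the combiner at the same
# position; plain callables such as sum and max are wrapped via
# CombineFn.maybe_from_callable:
#
#   with beam.Pipeline() as p:
#     stats = (
#         p | beam.Create([('a', (1, 10)), ('a', (2, 20))])
#         | beam.CombinePerKey(
#             beam.combiners.TupleCombineFn(sum, max)))  # -> [('a', (3, 20))]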

class SingleInputTupleCombineFn(_TupleCombineFnBase):
  """A combiner for combining a single value via a tuple of combiners.

  Takes as input a tuple of N CombineFns and combines elements by
  applying each CombineFn to each input, producing an N-tuple of
  the outputs corresponding to each of the N CombineFn's outputs.
  """
  def add_input(self, accumulator, element, *args, **kwargs):
    return [
        c.add_input(a, element, *args, **kwargs) for c,
        a in zip(self._combiners, accumulator)
    ]

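# Usage sketch for SingleInputTupleCombineFn (assumes
# `import apache_beam as beam`): every combiner sees each whole input
# element, yielding one output slot per combiner:
#
#   with beam.Pipeline() as p:
#     summary = (
#         p | beam.Create([3, 1, 4, 1, 5])
#         | beam.CombineGlobally(
#             beam.combiners.SingleInputTupleCombineFn(
#                 min, max, sum)))  # -> [(1, 5, 14)]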

@with_input_types(T)
@with_output_types(List[T])
class ToList(CombinerWithoutDefaults):
  """A global combiner that condenses a PCollection into a single list."""
  def expand(self, pcoll):
    if self.has_defaults:
      return pcoll | self.label >> core.CombineGlobally(ToListCombineFn())
    else:
      return pcoll | self.label >> core.CombineGlobally(
          ToListCombineFn()).without_defaults()


@with_input_types(T)
@with_output_types(List[T])
class ToListCombineFn(core.CombineFn):
  """CombineFn for ToList."""
  def create_accumulator(self):
    return []

  def add_input(self, accumulator, element):
    accumulator.append(element)
    return accumulator

  def merge_accumulators(self, accumulators):
    return sum(accumulators, [])

  def extract_output(self, accumulator):
    return accumulator


@with_input_types(Tuple[K, V])
@with_output_types(Dict[K, V])
class ToDict(CombinerWithoutDefaults):
  """A global combiner that condenses a PCollection into a single dict.

  PCollections should consist of 2-tuples, notionally (key, value) pairs.
  If multiple values are associated with the same key, only one of the values
  will be present in the resulting dict.
  """
  def expand(self, pcoll):
    if self.has_defaults:
      return pcoll | self.label >> core.CombineGlobally(ToDictCombineFn())
    else:
      return pcoll | self.label >> core.CombineGlobally(
          ToDictCombineFn()).without_defaults()


@with_input_types(Tuple[K, V])
@with_output_types(Dict[K, V])
class ToDictCombineFn(core.CombineFn):
  """CombineFn for ToDict."""
  def create_accumulator(self):
    return {}

  def add_input(self, accumulator, element):
    key, value = element
    accumulator[key] = value
    return accumulator

  def merge_accumulators(self, accumulators):
    result = {}
    for a in accumulators:
      result.update(a)
    return result

  def extract_output(self, accumulator):
    return accumulator


@with_input_types(T)
@with_output_types(Set[T])
class ToSet(CombinerWithoutDefaults):
  """A global combiner that condenses a PCollection into a set."""
  def expand(self, pcoll):
    if self.has_defaults:
      return pcoll | self.label >> core.CombineGlobally(ToSetCombineFn())
    else:
      return pcoll | self.label >> core.CombineGlobally(
          ToSetCombineFn()).without_defaults()


@with_input_types(T)
@with_output_types(Set[T])
class ToSetCombineFn(core.CombineFn):
  """CombineFn for ToSet."""
  def create_accumulator(self):
    return set()

  def add_input(self, accumulator, element):
    accumulator.add(element)
    return accumulator

  def merge_accumulators(self, accumulators):
    return set.union(*accumulators)

  def extract_output(self, accumulator):
    return accumulator

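# Usage sketch for the To* transforms (assumes `import apache_beam as beam`);
# element order inside the resulting list is not guaranteed:
#
#   with beam.Pipeline() as p:
#     as_list = p | beam.Create([2, 1]) | beam.combiners.ToList()
#     # -> [[2, 1]] (in some order)
#     as_dict = (
#         p | 'kv' >> beam.Create([('a', 1), ('b', 2)])
#         | beam.combiners.ToDict())  # -> [{'a': 1, 'b': 2}]
#     as_set = (
#         p | 'dups' >> beam.Create([1, 1, 2])
#         | beam.combiners.ToSet())  # -> [{1, 2}]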

class _CurriedFn(core.CombineFn):
  """Wrapped CombineFn with extra arguments."""
  def __init__(self, fn, args, kwargs):
    self.fn = fn
    self.args = args
    self.kwargs = kwargs

  def setup(self):
    self.fn.setup(*self.args, **self.kwargs)

  def create_accumulator(self):
    return self.fn.create_accumulator(*self.args, **self.kwargs)

  def add_input(self, accumulator, element):
    return self.fn.add_input(accumulator, element, *self.args, **self.kwargs)

  def merge_accumulators(self, accumulators):
    return self.fn.merge_accumulators(accumulators, *self.args, **self.kwargs)

  def compact(self, accumulator):
    return self.fn.compact(accumulator, *self.args, **self.kwargs)

  def extract_output(self, accumulator):
    return self.fn.extract_output(accumulator, *self.args, **self.kwargs)

  def teardown(self):
    self.fn.teardown(*self.args, **self.kwargs)

  def apply(self, elements):
    return self.fn.apply(elements, *self.args, **self.kwargs)


def curry_combine_fn(fn, args, kwargs):
  if not args and not kwargs:
    return fn
  else:
    return _CurriedFn(fn, args, kwargs)


class PhasedCombineFnExecutor(object):
  """Executor for phases of combine operations."""
  def __init__(self, phase, fn, args, kwargs):
    self.combine_fn = curry_combine_fn(fn, args, kwargs)

    if phase == 'all':
      self.apply = self.full_combine
    elif phase == 'add':
      self.apply = self.add_only
    elif phase == 'merge':
      self.apply = self.merge_only
    elif phase == 'extract':
      self.apply = self.extract_only
    elif phase == 'convert':
      self.apply = self.convert_to_accumulator
    else:
      raise ValueError('Unexpected phase: %s' % phase)

  def full_combine(self, elements):
    return self.combine_fn.apply(elements)

  def add_only(self, elements):
    return self.combine_fn.add_inputs(
        self.combine_fn.create_accumulator(), elements)

  def merge_only(self, accumulators):
    return self.combine_fn.merge_accumulators(accumulators)

  def extract_only(self, accumulator):
    return self.combine_fn.extract_output(accumulator)

  def convert_to_accumulator(self, element):
    return self.combine_fn.add_input(
        self.combine_fn.create_accumulator(), element)

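# Usage sketch: each phase runs one slice of the CombineFn lifecycle. For
# example, the 'add' phase folds raw elements into a fresh accumulator:
#
#   executor = PhasedCombineFnExecutor('add', CountCombineFn(), (), {})
#   executor.apply([10, 20, 30])  # -> 3 (an accumulator, not extracted output)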

class Latest(object):
  """Combiners for computing the latest element."""
  @with_input_types(T)
  @with_output_types(T)
  class Globally(CombinerWithoutDefaults):
    """Compute the element with the latest timestamp from a PCollection."""
    @staticmethod
    def add_timestamp(element, timestamp=core.DoFn.TimestampParam):
      return [(element, timestamp)]

    def expand(self, pcoll):
      if self.has_defaults:
        return (
            pcoll
            | core.ParDo(self.add_timestamp).with_output_types(
                Tuple[T, TimestampType])
            | core.CombineGlobally(LatestCombineFn()))
      else:
        return (
            pcoll
            | core.ParDo(self.add_timestamp).with_output_types(
                Tuple[T, TimestampType])
            | core.CombineGlobally(LatestCombineFn()).without_defaults())

  @with_input_types(Tuple[K, V])
  @with_output_types(Tuple[K, V])
  class PerKey(ptransform.PTransform):
    """Compute elements with the latest timestamp for each key from a keyed
    PCollection."""
    @staticmethod
    def add_timestamp(element, timestamp=core.DoFn.TimestampParam):
      key, value = element
      return [(key, (value, timestamp))]

    def expand(self, pcoll):
      return (
          pcoll
          | core.ParDo(self.add_timestamp).with_output_types(
              Tuple[K, Tuple[V, TimestampType]])
          | core.CombinePerKey(LatestCombineFn()))

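# Usage sketch for Latest (assumes `import apache_beam as beam`); values and
# timestamps are illustrative:
#
#   with beam.Pipeline() as p:
#     latest = (
#         p | beam.Create([('a', 1), ('b', 5), ('c', 3)])
#         | beam.Map(
#             lambda kv: beam.window.TimestampedValue(kv[0], kv[1]))
#         | beam.combiners.Latest.Globally())  # -> ['b']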

@with_input_types(Tuple[T, TimestampType])
@with_output_types(T)
class LatestCombineFn(core.CombineFn):
  """CombineFn to get the element with the latest timestamp
  from a PCollection."""
  def create_accumulator(self):
    return (None, window.MIN_TIMESTAMP)

  def add_input(self, accumulator, element):
    if accumulator[1] > element[1]:
      return accumulator
    else:
      return element

  def merge_accumulators(self, accumulators):
    result = self.create_accumulator()
    for accumulator in accumulators:
      result = self.add_input(result, accumulator)
    return result

  def extract_output(self, accumulator):
    return accumulator[0]