github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/cy_combiners.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # cython: language_level=3
    19  
    20  """A library of basic cythonized CombineFn subclasses.
    21  
    22  For internal use only; no backwards-compatibility guarantees.
    23  """
    24  
    25  # pytype: skip-file
    26  
    27  import operator
    28  
    29  from apache_beam.transforms import core
    30  
    31  try:
    32    from apache_beam.transforms.cy_dataflow_distribution_counter import DataflowDistributionCounter
    33  except ImportError:
    34    from apache_beam.transforms.py_dataflow_distribution_counter import DataflowDistributionCounter
    35  
    36  
    37  class AccumulatorCombineFn(core.CombineFn):
    38    # singleton?
    39    def create_accumulator(self):
    40      return self._accumulator_type()
    41  
    42    @staticmethod
    43    def add_input(accumulator, element):
    44      accumulator.add_input(element)
    45      return accumulator
    46  
    47    def merge_accumulators(self, accumulators):
    48      accumulator = self._accumulator_type()
    49      accumulator.merge(accumulators)
    50      return accumulator
    51  
    52    @staticmethod
    53    def extract_output(accumulator):
    54      return accumulator.extract_output()
    55  
    56    def __eq__(self, other):
    57      return (
    58          isinstance(other, AccumulatorCombineFn) and
    59          self._accumulator_type is other._accumulator_type)
    60  
    61    def __hash__(self):
    62      return hash(self._accumulator_type)
    63  
    64  
# Avoid large literals in C source code: the bounds below are computed from a
# small exponent at import time instead of being spelled out as constants.
_63 = 63
# Signed 64-bit integer bounds used by the *Int64Accumulator classes below.
# NOTE(review): the names are bound through globals() rather than by plain
# assignment, presumably so Cython does not treat them as ordinary module
# variables (a compiled build may supply its own INT64_MAX/INT64_MIN) --
# confirm against the accompanying .pxd before changing this.
globals()['INT64_MAX'] = 2**_63 - 1
globals()['INT64_MIN'] = -2**_63
    68  
    69  
    70  class CountAccumulator(object):
    71    def __init__(self):
    72      self.value = 0
    73  
    74    def add_input(self, unused_element):
    75      self.value += 1
    76  
    77    def add_input_n(self, unused_element, n):
    78      self.value += n
    79  
    80    def merge(self, accumulators):
    81      for accumulator in accumulators:
    82        self.value += accumulator.value
    83  
    84    def extract_output(self):
    85      return self.value
    86  
    87  
    88  class SumInt64Accumulator(object):
    89    def __init__(self):
    90      self.value = 0
    91  
    92    def add_input(self, element):
    93      global INT64_MAX, INT64_MIN  # pylint: disable=global-variable-not-assigned
    94      element = int(element)
    95      if not INT64_MIN <= element <= INT64_MAX:
    96        raise OverflowError(element)
    97      self.value += element
    98  
    99    def add_input_n(self, element, n):
   100      global INT64_MAX, INT64_MIN  # pylint: disable=global-variable-not-assigned
   101      element = int(element)
   102      if not INT64_MIN <= element <= INT64_MAX:
   103        raise OverflowError(element)
   104      self.value += element * n
   105  
   106    def merge(self, accumulators):
   107      for accumulator in accumulators:
   108        self.value += accumulator.value
   109  
   110    def extract_output(self):
   111      if not INT64_MIN <= self.value <= INT64_MAX:
   112        self.value %= 2**64
   113        if self.value >= INT64_MAX:
   114          self.value -= 2**64
   115      return self.value
   116  
   117  
   118  class MinInt64Accumulator(object):
   119    def __init__(self):
   120      self.value = INT64_MAX
   121  
   122    def add_input(self, element):
   123      element = int(element)
   124      if not INT64_MIN <= element <= INT64_MAX:
   125        raise OverflowError(element)
   126      if element < self.value:
   127        self.value = element
   128  
   129    def add_input_n(self, element, unused_n):
   130      self.add_input(element)
   131  
   132    def merge(self, accumulators):
   133      for accumulator in accumulators:
   134        if accumulator.value < self.value:
   135          self.value = accumulator.value
   136  
   137    def extract_output(self):
   138      return self.value
   139  
   140  
   141  class MaxInt64Accumulator(object):
   142    def __init__(self):
   143      self.value = INT64_MIN
   144  
   145    def add_input(self, element):
   146      element = int(element)
   147      if not INT64_MIN <= element <= INT64_MAX:
   148        raise OverflowError(element)
   149      if element > self.value:
   150        self.value = element
   151  
   152    def add_input_n(self, element, unused_n):
   153      self.add_input(element)
   154  
   155    def merge(self, accumulators):
   156      for accumulator in accumulators:
   157        if accumulator.value > self.value:
   158          self.value = accumulator.value
   159  
   160    def extract_output(self):
   161      return self.value
   162  
   163  
   164  class MeanInt64Accumulator(object):
   165    def __init__(self):
   166      self.sum = 0
   167      self.count = 0
   168  
   169    def add_input(self, element):
   170      element = int(element)
   171      if not INT64_MIN <= element <= INT64_MAX:
   172        raise OverflowError(element)
   173      self.sum += element
   174      self.count += 1
   175  
   176    def add_input_n(self, element, n):
   177      element = int(element)
   178      if not INT64_MIN <= element <= INT64_MAX:
   179        raise OverflowError(element)
   180      self.sum += element * n
   181      self.count += n
   182  
   183    def merge(self, accumulators):
   184      for accumulator in accumulators:
   185        self.sum += accumulator.sum
   186        self.count += accumulator.count
   187  
   188    def extract_output(self):
   189      if not INT64_MIN <= self.sum <= INT64_MAX:
   190        self.sum %= 2**64
   191        if self.sum >= INT64_MAX:
   192          self.sum -= 2**64
   193      return self.sum // self.count if self.count else _NAN
   194  
   195  
   196  class DistributionInt64Accumulator(object):
   197    def __init__(self):
   198      self.sum = 0
   199      self.count = 0
   200      self.min = INT64_MAX
   201      self.max = INT64_MIN
   202  
   203    def add_input(self, element):
   204      element = int(element)
   205      if not INT64_MIN <= element <= INT64_MAX:
   206        raise OverflowError(element)
   207      self.sum += element
   208      self.count += 1
   209      self.min = min(self.min, element)
   210      self.max = max(self.max, element)
   211  
   212    def add_input_n(self, element, n):
   213      element = int(element)
   214      if not INT64_MIN <= element <= INT64_MAX:
   215        raise OverflowError(element)
   216      self.sum += element * n
   217      self.count += n
   218      self.min = min(self.min, element)
   219      self.max = max(self.max, element)
   220  
   221    def merge(self, accumulators):
   222      for accumulator in accumulators:
   223        self.sum += accumulator.sum
   224        self.count += accumulator.count
   225        self.min = min(self.min, accumulator.min)
   226        self.max = max(self.max, accumulator.max)
   227  
   228    def extract_output(self):
   229      if not INT64_MIN <= self.sum <= INT64_MAX:
   230        self.sum %= 2**64
   231        if self.sum >= INT64_MAX:
   232          self.sum -= 2**64
   233      mean = self.sum // self.count if self.count else _NAN
   234      return mean, self.sum, self.count, self.min, self.max
   235  
   236  
class CountCombineFn(AccumulatorCombineFn):
  """Counts input elements using CountAccumulator."""
  _accumulator_type = CountAccumulator
   239  
   240  
class SumInt64Fn(AccumulatorCombineFn):
  """Sums int64 inputs using SumInt64Accumulator."""
  _accumulator_type = SumInt64Accumulator
   243  
   244  
class MinInt64Fn(AccumulatorCombineFn):
  """Computes the minimum of int64 inputs using MinInt64Accumulator."""
  _accumulator_type = MinInt64Accumulator
   247  
   248  
class MaxInt64Fn(AccumulatorCombineFn):
  """Computes the maximum of int64 inputs using MaxInt64Accumulator."""
  _accumulator_type = MaxInt64Accumulator
   251  
   252  
class MeanInt64Fn(AccumulatorCombineFn):
  """Computes the integer mean of int64 inputs using MeanInt64Accumulator."""
  _accumulator_type = MeanInt64Accumulator
   255  
   256  
class DistributionInt64Fn(AccumulatorCombineFn):
  """Produces (mean, sum, count, min, max) of int64 inputs.

  Backed by DistributionInt64Accumulator.
  """
  _accumulator_type = DistributionInt64Accumulator
   259  
   260  
   261  _POS_INF = float('inf')
   262  _NEG_INF = float('-inf')
   263  _NAN = float('nan')
   264  
   265  
   266  class SumDoubleAccumulator(object):
   267    def __init__(self):
   268      self.value = 0
   269  
   270    def add_input(self, element):
   271      element = float(element)
   272      self.value += element
   273  
   274    def merge(self, accumulators):
   275      for accumulator in accumulators:
   276        self.value += accumulator.value
   277  
   278    def extract_output(self):
   279      return self.value
   280  
   281  
   282  class MinDoubleAccumulator(object):
   283    def __init__(self):
   284      self.value = _POS_INF
   285  
   286    def add_input(self, element):
   287      element = float(element)
   288      if element < self.value:
   289        self.value = element
   290  
   291    def merge(self, accumulators):
   292      for accumulator in accumulators:
   293        if accumulator.value < self.value:
   294          self.value = accumulator.value
   295  
   296    def extract_output(self):
   297      return self.value
   298  
   299  
   300  class MaxDoubleAccumulator(object):
   301    def __init__(self):
   302      self.value = _NEG_INF
   303  
   304    def add_input(self, element):
   305      element = float(element)
   306      if element > self.value:
   307        self.value = element
   308  
   309    def merge(self, accumulators):
   310      for accumulator in accumulators:
   311        if accumulator.value > self.value:
   312          self.value = accumulator.value
   313  
   314    def extract_output(self):
   315      return self.value
   316  
   317  
   318  class MeanDoubleAccumulator(object):
   319    def __init__(self):
   320      self.sum = 0
   321      self.count = 0
   322  
   323    def add_input(self, element):
   324      element = float(element)
   325      self.sum += element
   326      self.count += 1
   327  
   328    def merge(self, accumulators):
   329      for accumulator in accumulators:
   330        self.sum += accumulator.sum
   331        self.count += accumulator.count
   332  
   333    def extract_output(self):
   334      return self.sum // self.count if self.count else _NAN
   335  
   336  
class SumFloatFn(AccumulatorCombineFn):
  """Sums inputs as floats using SumDoubleAccumulator."""
  _accumulator_type = SumDoubleAccumulator
   339  
   340  
class MinFloatFn(AccumulatorCombineFn):
  """Computes the minimum of inputs as floats using MinDoubleAccumulator."""
  _accumulator_type = MinDoubleAccumulator
   343  
   344  
class MaxFloatFn(AccumulatorCombineFn):
  """Computes the maximum of inputs as floats using MaxDoubleAccumulator."""
  _accumulator_type = MaxDoubleAccumulator
   347  
   348  
class MeanFloatFn(AccumulatorCombineFn):
  """Computes the mean of inputs as floats using MeanDoubleAccumulator."""
  _accumulator_type = MeanDoubleAccumulator
   351  
   352  
   353  class AllAccumulator(object):
   354    def __init__(self):
   355      self.value = True
   356  
   357    def add_input(self, element):
   358      self.value &= not not element
   359  
   360    def merge(self, accumulators):
   361      for accumulator in accumulators:
   362        self.value &= accumulator.value
   363  
   364    def extract_output(self):
   365      return self.value
   366  
   367  
   368  class AnyAccumulator(object):
   369    def __init__(self):
   370      self.value = False
   371  
   372    def add_input(self, element):
   373      self.value |= not not element
   374  
   375    def merge(self, accumulators):
   376      for accumulator in accumulators:
   377        self.value |= accumulator.value
   378  
   379    def extract_output(self):
   380      return self.value
   381  
   382  
class AnyCombineFn(AccumulatorCombineFn):
  """True if any input is truthy, using AnyAccumulator."""
  _accumulator_type = AnyAccumulator
   385  
   386  
class AllCombineFn(AccumulatorCombineFn):
  """True if every input is truthy, using AllAccumulator."""
  _accumulator_type = AllAccumulator
   389  
   390  
class DataflowDistributionCounterFn(AccumulatorCombineFn):
  """An AccumulatorCombineFn backed by DataflowDistributionCounter.

  Lets a DataflowDistributionCounter report to the Dataflow service through
  CounterFactory.

  When the cythonized DataflowDistributionCounter is available it is used as
  the accumulator; otherwise the pure-Python version is used (chosen by the
  guarded import at the top of this module).
  """
  _accumulator_type = DataflowDistributionCounter
   402  
   403  
   404  class ComparableValue(object):
   405    """A way to allow comparing elements in a rich fashion."""
   406  
   407    __slots__ = (
   408        'value', '_less_than_fn', '_comparable_value', 'requires_hydration')
   409  
   410    def __init__(self, value, less_than_fn, key_fn, _requires_hydration=False):
   411      self.value = value
   412      self.hydrate(less_than_fn, key_fn)
   413      self.requires_hydration = _requires_hydration
   414  
   415    def hydrate(self, less_than_fn, key_fn):
   416      self._less_than_fn = less_than_fn if less_than_fn else operator.lt
   417      self._comparable_value = key_fn(self.value) if key_fn else self.value
   418      self.requires_hydration = False
   419  
   420    def __lt__(self, other):
   421      assert not self.requires_hydration
   422      assert self._less_than_fn is other._less_than_fn
   423      return self._less_than_fn(self._comparable_value, other._comparable_value)
   424  
   425    def __repr__(self):
   426      return 'ComparableValue[%s]' % str(self.value)
   427  
   428    def __reduce__(self):
   429      # Since we can't pickle the Compare and Key Fn we pass None and we signify
   430      # that this object _requires_hydration.
   431      return ComparableValue, (self.value, None, None, True)