github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/stats.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # cython: language_level=3
    19  
    20  """This module has all statistic related transforms.
    21  
    22  This ApproximateUnique class will be deprecated [1]. PLease look into using
    23  HLLCount in the zetasketch extension module [2].
    24  
    25  [1] https://lists.apache.org/thread.html/501605df5027567099b81f18c080469661fb426
    26  4a002615fa1510502%40%3Cdev.beam.apache.org%3E
    27  [2] https://beam.apache.org/releases/javadoc/2.16.0/org/apache/beam/sdk/extensio
    28  ns/zetasketch/HllCount.html
    29  """
    30  
    31  # pytype: skip-file
    32  
    33  import hashlib
    34  import heapq
    35  import itertools
    36  import logging
    37  import math
    38  import typing
    39  from typing import Any
    40  from typing import Callable
    41  from typing import List
    42  from typing import Tuple
    43  
    44  from apache_beam import coders
    45  from apache_beam import typehints
    46  from apache_beam.transforms.core import *
    47  from apache_beam.transforms.display import DisplayDataItem
    48  from apache_beam.transforms.ptransform import PTransform
    49  
    50  __all__ = [
    51      'ApproximateQuantiles',
    52      'ApproximateUnique',
    53  ]
    54  
    55  # Type variables
    56  T = typing.TypeVar('T')
    57  K = typing.TypeVar('K')
    58  V = typing.TypeVar('V')
    59  
    60  try:
    61    import mmh3  # pylint: disable=import-error
    62  
    63    def _mmh3_hash(value):
    64      # mmh3.hash64 returns two 64-bit unsigned integers
    65      return mmh3.hash64(value, seed=0, signed=False)[0]
    66  
    67    _default_hash_fn = _mmh3_hash
    68    _default_hash_fn_type = 'mmh3'
    69  except ImportError:
    70  
    71    def _md5_hash(value):
    72      # md5 is a 128-bit hash, so we truncate the hexdigest (string of 32
    73      # hexadecimal digits) to 16 digits and convert to int to get the 64-bit
    74      # integer fingerprint.
    75      return int(hashlib.md5(value).hexdigest()[:16], 16)
    76  
    77    _default_hash_fn = _md5_hash
    78    _default_hash_fn_type = 'md5'
    79  
    80  
    81  def _get_default_hash_fn():
    82    """Returns either murmurhash or md5 based on installation."""
    83    if _default_hash_fn_type == 'md5':
    84      logging.warning(
    85          'Couldn\'t find murmurhash. Install mmh3 for a faster implementation of'
    86          'ApproximateUnique.')
    87    return _default_hash_fn
    88  
    89  
    90  class ApproximateUnique(object):
    91    """
    92    Hashes input elements and uses those to extrapolate the size of the entire
    93    set of hash values by assuming the rest of the hash values are as densely
    94    distributed as the sample space.
    95    """
    96  
    97    _NO_VALUE_ERR_MSG = 'Either size or error should be set. Received {}.'
    98    _MULTI_VALUE_ERR_MSG = 'Either size or error should be set. ' \
    99                           'Received {size = %s, error = %s}.'
   100    _INPUT_SIZE_ERR_MSG = 'ApproximateUnique needs a size >= 16 for an error ' \
   101                          '<= 0.50. In general, the estimation error is about ' \
   102                          '2 / sqrt(sample_size). Received {size = %s}.'
   103    _INPUT_ERROR_ERR_MSG = 'ApproximateUnique needs an estimation error ' \
   104                           'between 0.01 and 0.50. Received {error = %s}.'
   105  
   106    @staticmethod
   107    def parse_input_params(size=None, error=None):
   108      """
   109      Check if input params are valid and return sample size.
   110  
   111      :param size: an int not smaller than 16, which we would use to estimate
   112        number of unique values.
   113      :param error: max estimation error, which is a float between 0.01 and 0.50.
   114        If error is given, sample size will be calculated from error with
   115        _get_sample_size_from_est_error function.
   116      :return: sample size
   117      :raises:
   118        ValueError: If both size and error are given, or neither is given, or
   119        values are out of range.
   120      """
   121  
   122      if None not in (size, error):
   123        raise ValueError(ApproximateUnique._MULTI_VALUE_ERR_MSG % (size, error))
   124      elif size is None and error is None:
   125        raise ValueError(ApproximateUnique._NO_VALUE_ERR_MSG)
   126      elif size is not None:
   127        if not isinstance(size, int) or size < 16:
   128          raise ValueError(ApproximateUnique._INPUT_SIZE_ERR_MSG % (size))
   129        else:
   130          return size
   131      else:
   132        if error < 0.01 or error > 0.5:
   133          raise ValueError(ApproximateUnique._INPUT_ERROR_ERR_MSG % (error))
   134        else:
   135          return ApproximateUnique._get_sample_size_from_est_error(error)
   136  
   137    @staticmethod
   138    def _get_sample_size_from_est_error(est_err):
   139      """
   140      :return: sample size
   141  
   142      Calculate sample size from estimation error
   143      """
   144      return math.ceil(4.0 / math.pow(est_err, 2.0))
   145  
   146    @typehints.with_input_types(T)
   147    @typehints.with_output_types(int)
   148    class Globally(PTransform):
   149      """ Approximate.Globally approximate number of unique values"""
   150      def __init__(self, size=None, error=None):
   151        self._sample_size = ApproximateUnique.parse_input_params(size, error)
   152  
   153      def expand(self, pcoll):
   154        coder = coders.registry.get_coder(pcoll)
   155        return pcoll \
   156               | 'CountGlobalUniqueValues' \
   157               >> (CombineGlobally(ApproximateUniqueCombineFn(self._sample_size,
   158                                                              coder)))
   159  
   160    @typehints.with_input_types(typing.Tuple[K, V])
   161    @typehints.with_output_types(typing.Tuple[K, int])
   162    class PerKey(PTransform):
   163      """ Approximate.PerKey approximate number of unique values per key"""
   164      def __init__(self, size=None, error=None):
   165        self._sample_size = ApproximateUnique.parse_input_params(size, error)
   166  
   167      def expand(self, pcoll):
   168        coder = coders.registry.get_coder(pcoll)
   169        return pcoll \
   170               | 'CountPerKeyUniqueValues' \
   171               >> (CombinePerKey(ApproximateUniqueCombineFn(self._sample_size,
   172                                                            coder)))
   173  
   174  
   175  class _LargestUnique(object):
   176    """
   177    An object to keep samples and calculate sample hash space. It is an
   178    accumulator of a combine function.
   179    """
   180    # We use unsigned 64-bit integer hashes.
   181    _HASH_SPACE_SIZE = 2.0**64
   182  
   183    def __init__(self, sample_size):
   184      self._sample_size = sample_size
   185      self._min_hash = 2.0**64
   186      self._sample_heap = []
   187      self._sample_set = set()
   188  
   189    def add(self, element):
   190      """
   191      :param an element from pcoll.
   192      :return: boolean type whether the value is in the heap
   193  
   194      Adds a value to the heap, returning whether the value is (large enough to
   195      be) in the heap.
   196      """
   197      if len(self._sample_heap) >= self._sample_size and element < self._min_hash:
   198        return False
   199  
   200      if element not in self._sample_set:
   201        self._sample_set.add(element)
   202        heapq.heappush(self._sample_heap, element)
   203  
   204        if len(self._sample_heap) > self._sample_size:
   205          temp = heapq.heappop(self._sample_heap)
   206          self._sample_set.remove(temp)
   207          self._min_hash = self._sample_heap[0]
   208        elif element < self._min_hash:
   209          self._min_hash = element
   210      return True
   211  
   212    def get_estimate(self):
   213      """
   214      :return: estimation count of unique values
   215  
   216      If heap size is smaller than sample size, just return heap size.
   217      Otherwise, takes into account the possibility of hash collisions,
   218      which become more likely than not for 2^32 distinct elements.
   219      Note that log(1+x) ~ x for small x, so for sampleSize << maxHash
   220      log(1 - sample_size/sample_space) / log(1 - 1/sample_space) ~ sample_size
   221      and hence estimate ~ sample_size * hash_space / sample_space
   222      as one would expect.
   223  
   224      Given sample_size / sample_space = est / hash_space
   225      est = sample_size * hash_space / sample_space
   226  
   227      Given above sample_size approximate,
   228      est = log1p(-sample_size/sample_space) / log1p(-1/sample_space)
   229        * hash_space / sample_space
   230      """
   231      if len(self._sample_heap) < self._sample_size:
   232        return len(self._sample_heap)
   233      else:
   234        sample_space_size = self._HASH_SPACE_SIZE - 1.0 * self._min_hash
   235        est = (
   236            math.log1p(-self._sample_size / sample_space_size) /
   237            math.log1p(-1 / sample_space_size) * self._HASH_SPACE_SIZE /
   238            sample_space_size)
   239        return round(est)
   240  
   241  
   242  class ApproximateUniqueCombineFn(CombineFn):
   243    """
   244    ApproximateUniqueCombineFn computes an estimate of the number of
   245    unique values that were combined.
   246    """
   247    def __init__(self, sample_size, coder):
   248      self._sample_size = sample_size
   249      coder = coders.typecoders.registry.verify_deterministic(
   250          coder, 'ApproximateUniqueCombineFn')
   251  
   252      self._coder = coder
   253      self._hash_fn = _get_default_hash_fn()
   254  
   255    def create_accumulator(self, *args, **kwargs):
   256      return _LargestUnique(self._sample_size)
   257  
   258    def add_input(self, accumulator, element, *args, **kwargs):
   259      try:
   260        hashed_value = self._hash_fn(self._coder.encode(element))
   261        accumulator.add(hashed_value)
   262        return accumulator
   263      except Exception as e:
   264        raise RuntimeError("Runtime exception: %s" % e)
   265  
   266    # created an issue https://github.com/apache/beam/issues/19459 to speed up
   267    # merge process.
   268    def merge_accumulators(self, accumulators, *args, **kwargs):
   269      merged_accumulator = self.create_accumulator()
   270      for accumulator in accumulators:
   271        for i in accumulator._sample_heap:
   272          merged_accumulator.add(i)
   273  
   274      return merged_accumulator
   275  
   276    @staticmethod
   277    def extract_output(accumulator):
   278      return accumulator.get_estimate()
   279  
   280    def display_data(self):
   281      return {'sample_size': self._sample_size}
   282  
   283  
   284  class ApproximateQuantiles(object):
   285    """
   286    PTransform for getting the idea of data distribution using approximate N-tile
   287    (e.g. quartiles, percentiles etc.) either globally or per-key.
   288  
   289    Examples:
   290  
   291      in: list(range(101)), num_quantiles=5
   292  
   293      out: [0, 25, 50, 75, 100]
   294  
   295      in: [(i, 1 if i<10 else 1e-5) for i in range(101)], num_quantiles=5,
   296        weighted=True
   297  
   298      out: [0, 2, 5, 7, 100]
   299  
   300      in: [list(range(10)), ..., list(range(90, 101))], num_quantiles=5,
   301        input_batched=True
   302  
   303      out: [0, 25, 50, 75, 100]
   304  
   305      in: [(list(range(10)), [1]*10), (list(range(10)), [0]*10), ...,
   306        (list(range(90, 101)), [0]*11)], num_quantiles=5, input_batched=True,
   307        weighted=True
   308  
   309      out: [0, 2, 5, 7, 100]
   310    """
   311    @staticmethod
   312    def _display_data(num_quantiles, key, reverse, weighted, input_batched):
   313      return {
   314          'num_quantiles': DisplayDataItem(num_quantiles, label='Quantile Count'),
   315          'key': DisplayDataItem(
   316              key.__name__
   317              if hasattr(key, '__name__') else key.__class__.__name__,
   318              label='Record Comparer Key'),
   319          'reverse': DisplayDataItem(str(reverse), label='Is Reversed'),
   320          'weighted': DisplayDataItem(str(weighted), label='Is Weighted'),
   321          'input_batched': DisplayDataItem(
   322              str(input_batched), label='Is Input Batched'),
   323      }
   324  
   325    @typehints.with_input_types(
   326        typehints.Union[typing.Sequence[T], typing.Tuple[T, float]])
   327    @typehints.with_output_types(typing.List[T])
   328    class Globally(PTransform):
   329      """
   330      PTransform takes PCollection and returns a list whose single value is
   331      approximate N-tiles of the input collection globally.
   332  
   333      Args:
   334        num_quantiles: number of elements in the resulting quantiles values list.
   335        key: (optional) Key is  a mapping of elements to a comparable key, similar
   336          to the key argument of Python's sorting methods.
   337        reverse: (optional) whether to order things smallest to largest, rather
   338          than largest to smallest.
   339        weighted: (optional) if set to True, the transform returns weighted
   340          quantiles. The input PCollection is then expected to contain tuples of
   341          input values with the corresponding weight.
   342        input_batched: (optional) if set to True, the transform expects each
   343          element of input PCollection to be a batch, which is a list of elements
   344          for non-weighted case and a tuple of lists of elements and weights for
   345          weighted. Provides a way to accumulate multiple elements at a time more
   346          efficiently.
   347      """
   348      def __init__(
   349          self,
   350          num_quantiles,
   351          key=None,
   352          reverse=False,
   353          weighted=False,
   354          input_batched=False):
   355        self._num_quantiles = num_quantiles
   356        self._key = key
   357        self._reverse = reverse
   358        self._weighted = weighted
   359        self._input_batched = input_batched
   360  
   361      def expand(self, pcoll):
   362        return pcoll | CombineGlobally(
   363            ApproximateQuantilesCombineFn.create(
   364                num_quantiles=self._num_quantiles,
   365                key=self._key,
   366                reverse=self._reverse,
   367                weighted=self._weighted,
   368                input_batched=self._input_batched))
   369  
   370      def display_data(self):
   371        return ApproximateQuantiles._display_data(
   372            num_quantiles=self._num_quantiles,
   373            key=self._key,
   374            reverse=self._reverse,
   375            weighted=self._weighted,
   376            input_batched=self._input_batched)
   377  
   378    @typehints.with_input_types(
   379        typehints.Union[typing.Tuple[K, V],
   380                        typing.Tuple[K, typing.Tuple[V, float]]])
   381    @typehints.with_output_types(typing.Tuple[K, typing.List[V]])
   382    class PerKey(PTransform):
   383      """
   384      PTransform takes PCollection of KV and returns a list based on each key
   385      whose single value is list of approximate N-tiles of the input element of
   386      the key.
   387  
   388      Args:
   389        num_quantiles: number of elements in the resulting quantiles values list.
   390        key: (optional) Key is  a mapping of elements to a comparable key, similar
   391          to the key argument of Python's sorting methods.
   392        reverse: (optional) whether to order things smallest to largest, rather
   393          than largest to smallest.
   394        weighted: (optional) if set to True, the transform returns weighted
   395          quantiles. The input PCollection is then expected to contain tuples of
   396          input values with the corresponding weight.
   397        input_batched: (optional) if set to True, the transform expects each
   398          element of input PCollection to be a batch, which is a list of elements
   399          for non-weighted case and a tuple of lists of elements and weights for
   400          weighted. Provides a way to accumulate multiple elements at a time more
   401          efficiently.
   402      """
   403      def __init__(
   404          self,
   405          num_quantiles,
   406          key=None,
   407          reverse=False,
   408          weighted=False,
   409          input_batched=False):
   410        self._num_quantiles = num_quantiles
   411        self._key = key
   412        self._reverse = reverse
   413        self._weighted = weighted
   414        self._input_batched = input_batched
   415  
   416      def expand(self, pcoll):
   417        return pcoll | CombinePerKey(
   418            ApproximateQuantilesCombineFn.create(
   419                num_quantiles=self._num_quantiles,
   420                key=self._key,
   421                reverse=self._reverse,
   422                weighted=self._weighted,
   423                input_batched=self._input_batched))
   424  
   425      def display_data(self):
   426        return ApproximateQuantiles._display_data(
   427            num_quantiles=self._num_quantiles,
   428            key=self._key,
   429            reverse=self._reverse,
   430            weighted=self._weighted,
   431            input_batched=self._input_batched)
   432  
   433  
   434  class _QuantileSpec(object):
   435    """Quantiles computation specifications."""
   436    def __init__(self, buffer_size, num_buffers, weighted, key, reverse):
   437      # type: (int, int, bool, Any, bool) -> None
   438      self.buffer_size = buffer_size
   439      self.num_buffers = num_buffers
   440      self.weighted = weighted
   441      self.key = key
   442      self.reverse = reverse
   443  
   444      # Used to sort tuples of values and weights.
   445      self.weighted_key = None if key is None else (lambda x: key(x[0]))
   446  
   447      # Used to compare values.
   448      if reverse and key is None:
   449        self.less_than = lambda a, b: a > b
   450      elif reverse:
   451        self.less_than = lambda a, b: key(a) > key(b)
   452      elif key is None:
   453        self.less_than = lambda a, b: a < b
   454      else:
   455        self.less_than = lambda a, b: key(a) < key(b)
   456  
   457    def get_argsort_key(self, elements):
   458      # type: (List) -> Callable[[int], Any]
   459  
   460      """Returns a key for sorting indices of elements by element's value."""
   461      if self.key is None:
   462        return elements.__getitem__
   463      else:
   464        return lambda idx: self.key(elements[idx])
   465  
   466    def __reduce__(self):
   467      return (
   468          self.__class__,
   469          (
   470              self.buffer_size,
   471              self.num_buffers,
   472              self.weighted,
   473              self.key,
   474              self.reverse))
   475  
   476  
   477  class _QuantileBuffer(object):
   478    """A single buffer in the sense of the referenced algorithm.
   479    (see http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.6513&rep=rep1
   480    &type=pdf and ApproximateQuantilesCombineFn for further information)"""
   481    def __init__(
   482        self, elements, weights, weighted, level=0, min_val=None, max_val=None):
   483      # type: (List, List, bool, int, Any, Any) -> None
   484      self.elements = elements
   485      # In non-weighted case weights contains a single element representing weight
   486      # of the buffer in the sense of the original algorithm. In weighted case,
   487      # it stores weights of individual elements.
   488      self.weights = weights
   489      self.weighted = weighted
   490      self.level = level
   491      if min_val is None or max_val is None:
   492        # Buffer is always initialized with sorted elements.
   493        self.min_val = elements[0]
   494        self.max_val = elements[-1]
   495      else:
   496        # Note that collapsed buffer may not contain min and max in the list of
   497        # elements.
   498        self.min_val = min_val
   499        self.max_val = max_val
   500  
   501    def __iter__(self):
   502      return zip(
   503          self.elements,
   504          self.weights if self.weighted else itertools.repeat(self.weights[0]))
   505  
   506    def __lt__(self, other):
   507      return self.level < other.level
   508  
   509  
   510  class _QuantileState(object):
   511    """
   512    Compact summarization of a collection on which quantiles can be estimated.
   513    """
   514    def __init__(self, unbuffered_elements, unbuffered_weights, buffers, spec):
   515      # type: (List, List, List[_QuantileBuffer], _QuantileSpec) -> None
   516      self.buffers = buffers
   517      self.spec = spec
   518      if spec.weighted:
   519        self.add_unbuffered = self._add_unbuffered_weighted
   520      else:
   521        self.add_unbuffered = self._add_unbuffered
   522  
   523      # The algorithm requires that the manipulated buffers always be filled to
   524      # capacity to perform the collapse operation. This operation can be extended
   525      # to buffers of varying sizes by introducing the notion of fractional
   526      # weights, but it's easier to simply combine the remainders from all shards
   527      # into new, full buffers and then take them into account when computing the
   528      # final output.
   529      self.unbuffered_elements = unbuffered_elements
   530      self.unbuffered_weights = unbuffered_weights
   531  
   532    # This is needed for pickling to work when Cythonization is enabled.
   533    def __reduce__(self):
   534      return (
   535          self.__class__,
   536          (
   537              self.unbuffered_elements,
   538              self.unbuffered_weights,
   539              self.buffers,
   540              self.spec))
   541  
   542    def is_empty(self):
   543      # type: () -> bool
   544  
   545      """Check if the buffered & unbuffered elements are empty or not."""
   546      return not self.unbuffered_elements and not self.buffers
   547  
   548    def _add_unbuffered(self, elements, offset_fn):
   549      # type: (List, Any) -> None
   550  
   551      """
   552      Add elements to the unbuffered list, creating new buffers and
   553      collapsing if needed.
   554      """
   555      self.unbuffered_elements.extend(elements)
   556      num_new_buffers = len(self.unbuffered_elements) // self.spec.buffer_size
   557      for idx in range(num_new_buffers):
   558        to_buffer = sorted(
   559            self.unbuffered_elements[idx * self.spec.buffer_size:(idx + 1) *
   560                                     self.spec.buffer_size],
   561            key=self.spec.key,
   562            reverse=self.spec.reverse)
   563        heapq.heappush(
   564            self.buffers,
   565            _QuantileBuffer(elements=to_buffer, weights=[1], weighted=False))
   566  
   567      if num_new_buffers > 0:
   568        self.unbuffered_elements = self.unbuffered_elements[num_new_buffers *
   569                                                            self.spec.
   570                                                            buffer_size:]
   571  
   572      self.collapse_if_needed(offset_fn)
   573  
   574    def _add_unbuffered_weighted(self, elements, offset_fn):
   575      # type: (List, Any) -> None
   576  
   577      """
   578      Add elements with weights to the unbuffered list, creating new buffers and
   579      collapsing if needed.
   580      """
   581      if len(elements) == 1:
   582        self.unbuffered_elements.append(elements[0][0])
   583        self.unbuffered_weights.append(elements[0][1])
   584      else:
   585        self.unbuffered_elements.extend(elements[0])
   586        self.unbuffered_weights.extend(elements[1])
   587      num_new_buffers = len(self.unbuffered_elements) // self.spec.buffer_size
   588      argsort_key = self.spec.get_argsort_key(self.unbuffered_elements)
   589      for idx in range(num_new_buffers):
   590        argsort = sorted(
   591            range(idx * self.spec.buffer_size, (idx + 1) * self.spec.buffer_size),
   592            key=argsort_key,
   593            reverse=self.spec.reverse)
   594        elements_to_buffer = [self.unbuffered_elements[idx] for idx in argsort]
   595        weights_to_buffer = [self.unbuffered_weights[idx] for idx in argsort]
   596        heapq.heappush(
   597            self.buffers,
   598            _QuantileBuffer(
   599                elements=elements_to_buffer,
   600                weights=weights_to_buffer,
   601                weighted=True))
   602  
   603      if num_new_buffers > 0:
   604        self.unbuffered_elements = self.unbuffered_elements[num_new_buffers *
   605                                                            self.spec.
   606                                                            buffer_size:]
   607        self.unbuffered_weights = self.unbuffered_weights[num_new_buffers *
   608                                                          self.spec.buffer_size:]
   609  
   610      self.collapse_if_needed(offset_fn)
   611  
   612    def finalize(self):
   613      # type: () -> None
   614  
   615      """
   616      Creates a new buffer using all unbuffered elements. Called before
   617      extracting an output. Note that the buffer doesn't have to be put in a
   618      proper position since _collapse is not going to be called after.
   619      """
   620      if self.unbuffered_elements and self.spec.weighted:
   621        argsort_key = self.spec.get_argsort_key(self.unbuffered_elements)
   622        argsort = sorted(
   623            range(len(self.unbuffered_elements)),
   624            key=argsort_key,
   625            reverse=self.spec.reverse)
   626        self.unbuffered_elements = [
   627            self.unbuffered_elements[idx] for idx in argsort
   628        ]
   629        self.unbuffered_weights = [
   630            self.unbuffered_weights[idx] for idx in argsort
   631        ]
   632        self.buffers.append(
   633            _QuantileBuffer(
   634                self.unbuffered_elements, self.unbuffered_weights, weighted=True))
   635        self.unbuffered_weights = []
   636      elif self.unbuffered_elements:
   637        self.unbuffered_elements.sort(
   638            key=self.spec.key, reverse=self.spec.reverse)
   639        self.buffers.append(
   640            _QuantileBuffer(
   641                self.unbuffered_elements, weights=[1], weighted=False))
   642      self.unbuffered_elements = []
   643  
   644    def collapse_if_needed(self, offset_fn):
   645      # type: (Any) -> None
   646  
   647      """
   648      Checks if summary has too many buffers and collapses some of them until the
   649      limit is restored.
   650      """
   651      while len(self.buffers) > self.spec.num_buffers:
   652        to_collapse = [heapq.heappop(self.buffers), heapq.heappop(self.buffers)]
   653        min_level = to_collapse[1].level
   654  
   655        while self.buffers and self.buffers[0].level <= min_level:
   656          to_collapse.append(heapq.heappop(self.buffers))
   657  
   658        heapq.heappush(self.buffers, _collapse(to_collapse, offset_fn, self.spec))
   659  
   660  
   661  def _collapse(buffers, offset_fn, spec):
   662    # type: (List[_QuantileBuffer], Any, _QuantileSpec) -> _QuantileBuffer
   663  
   664    """
   665    Approximates elements from multiple buffers and produces a single buffer.
   666    """
   667    new_level = 0
   668    new_weight = 0
   669    for buffer in buffers:
   670      # As presented in the paper, there should always be at least two
   671      # buffers of the same (minimal) level to collapse, but it is possible
   672      # to violate this condition when combining buffers from independently
   673      # computed shards. If they differ we take the max.
   674      new_level = max([new_level, buffer.level + 1])
   675      new_weight = new_weight + sum(buffer.weights)
   676    if spec.weighted:
   677      step = new_weight / (spec.buffer_size - 1)
   678      offset = new_weight / (2 * spec.buffer_size)
   679    else:
   680      step = new_weight
   681      offset = offset_fn(new_weight)
   682    new_elements, new_weights, min_val, max_val = \
   683        _interpolate(buffers, spec.buffer_size, step, offset, spec)
   684    if not spec.weighted:
   685      new_weights = [new_weight]
   686    return _QuantileBuffer(
   687        new_elements, new_weights, spec.weighted, new_level, min_val, max_val)
   688  
   689  
   690  def _interpolate(buffers, count, step, offset, spec):
   691    # type: (List[_QuantileBuffer], int, float, float, _QuantileSpec) -> Tuple[List, List, Any, Any]
   692  
   693    """
   694    Emulates taking the ordered union of all elements in buffers, repeated
   695    according to their weight, and picking out the (k * step + offset)-th elements
   696    of this list for `0 <= k < count`.
   697    """
   698    buffer_iterators = []
   699    min_val = buffers[0].min_val
   700    max_val = buffers[0].max_val
   701    for buffer in buffers:
   702      # Calculate extreme values for the union of buffers.
   703      min_val = buffer.min_val if spec.less_than(
   704          buffer.min_val, min_val) else min_val
   705      max_val = buffer.max_val if spec.less_than(
   706          max_val, buffer.max_val) else max_val
   707      buffer_iterators.append(iter(buffer))
   708  
   709    # Note that `heapq.merge` can also be used here since the buffers are sorted.
   710    # In practice, however, `sorted` uses natural order in the union and
   711    # significantly outperforms `heapq.merge`.
   712    sorted_elements = sorted(
   713        itertools.chain.from_iterable(buffer_iterators),
   714        key=spec.weighted_key,
   715        reverse=spec.reverse)
   716  
   717    if not spec.weighted:
   718      # If all buffers have the same weight, then quantiles' indices are evenly
   719      # distributed over a range [0, len(sorted_elements)].
   720      buffers_have_same_weight = True
   721      weight = buffers[0].weights[0]
   722      for buffer in buffers:
   723        if buffer.weights[0] != weight:
   724          buffers_have_same_weight = False
   725          break
   726      if buffers_have_same_weight:
   727        offset = offset / weight
   728        step = step / weight
   729        max_idx = len(sorted_elements) - 1
   730        result = [
   731            sorted_elements[min(int(j * step + offset), max_idx)][0]
   732            for j in range(count)
   733        ]
   734        return result, [], min_val, max_val
   735  
   736    sorted_elements_iter = iter(sorted_elements)
   737    weighted_element = next(sorted_elements_iter)
   738    new_elements = []
   739    new_weights = []
   740    j = 0
   741    current_weight = weighted_element[1]
   742    previous_weight = 0
   743    while j < count:
   744      target_weight = j * step + offset
   745      j += 1
   746      try:
   747        while current_weight <= target_weight:
   748          weighted_element = next(sorted_elements_iter)
   749          current_weight += weighted_element[1]
   750      except StopIteration:
   751        pass
   752      new_elements.append(weighted_element[0])
   753      if spec.weighted:
   754        new_weights.append(current_weight - previous_weight)
   755        previous_weight = current_weight
   756  
   757    return new_elements, new_weights, min_val, max_val
   758  
   759  
   760  class ApproximateQuantilesCombineFn(CombineFn):
   761    """
   762    This combiner gives an idea of the distribution of a collection of values
   763    using approximate N-tiles. The output of this combiner is the list of size of
   764    the number of quantiles (num_quantiles), containing the input values of the
   765    minimum value item of the list, the intermediate values (n-tiles) and the
   766    maximum value item of the list, in the sort order provided via key (similar
   767    to the key argument of Python's sorting methods).
   768  
   769    If there are fewer values to combine than the number of quantile
   770    (num_quantiles), then the resulting list will contain all the values being
   771    combined, in sorted order.
   772  
   773    If no `key` is provided, then the results are sorted in the natural order.
   774  
   775    To evaluate the quantiles, we use the "New Algorithm" described here:
   776  
   777    [MRL98] Manku, Rajagopalan & Lindsay, "Approximate Medians and other
   778    Quantiles in One Pass and with Limited Memory", Proc. 1998 ACM SIGMOD,
   779    Vol 27, No 2, p 426-435, June 1998.
   780    http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.6513&rep=rep1
   781    &type=pdf
   782  
   783    Note that the weighted quantiles are evaluated using a generalized version of
   784    the algorithm referenced in the paper.
   785  
   786    The default error bound is (1 / num_quantiles) for uniformly distributed data
   787    and min(1e-2, 1 / num_quantiles) for weighted case, though in practice the
   788    accuracy tends to be much better.
   789  
   790    Args:
   791      num_quantiles: Number of quantiles to produce. It is the size of the final
   792        output list, including the mininum and maximum value items.
   793      buffer_size: The size of the buffers, corresponding to k in the referenced
   794        paper.
   795      num_buffers: The number of buffers, corresponding to b in the referenced
   796        paper.
   797      key: (optional) Key is a mapping of elements to a comparable key, similar
   798        to the key argument of Python's sorting methods.
   799      reverse: (optional) whether to order things smallest to largest, rather
   800        than largest to smallest.
   801      weighted: (optional) if set to True, the combiner produces weighted
   802        quantiles. The input elements are then expected to be tuples of input
   803        values with the corresponding weight.
   804      input_batched: (optional) if set to True, inputs are expected to be batches
   805        of elements.
   806    """
   807  
   808    # For alternating between biasing up and down in the above even weight
   809    # collapse operation.
   810    _offset_jitter = 0
   811  
   812    # The cost (in time and space) to compute quantiles to a given accuracy is a
   813    # function of the total number of elements in the data set. If an estimate is
   814    # not known or specified, we use this as an upper bound. If this is too low,
   815    # errors may exceed the requested tolerance; if too high, efficiency may be
   816    # non-optimal. The impact is logarithmic with respect to this value, so this
   817    # default should be fine for most uses.
   818    _MAX_NUM_ELEMENTS = 1e9
   819    _qs = None  # type: _QuantileState
   820  
   821    def __init__(
   822        self,
   823        num_quantiles,  # type: int
   824        buffer_size,  # type: int
   825        num_buffers,  # type: int
   826        key=None,
   827        reverse=False,
   828        weighted=False,
   829        input_batched=False):
   830      self._num_quantiles = num_quantiles
   831      self._spec = _QuantileSpec(buffer_size, num_buffers, weighted, key, reverse)
   832      self._input_batched = input_batched
   833      if self._input_batched:
   834        setattr(self, 'add_input', self._add_inputs)
   835  
   836    def __reduce__(self):
   837      return (
   838          self.__class__,
   839          (
   840              self._num_quantiles,
   841              self._spec.buffer_size,
   842              self._spec.num_buffers,
   843              self._spec.key,
   844              self._spec.reverse,
   845              self._spec.weighted,
   846              self._input_batched))
   847  
   848    @classmethod
   849    def create(
   850        cls,
   851        num_quantiles,  # type: int
   852        epsilon=None,
   853        max_num_elements=None,
   854        key=None,
   855        reverse=False,
   856        weighted=False,
   857        input_batched=False):
   858      # type: (...) -> ApproximateQuantilesCombineFn
   859  
   860      """
   861      Creates an approximate quantiles combiner with the given key and desired
   862      number of quantiles.
   863  
   864      Args:
   865        num_quantiles: Number of quantiles to produce. It is the size of the
   866          final output list, including the mininum and maximum value items.
   867        epsilon: (optional) The default error bound is `epsilon`, which holds as
   868          long as the number of elements is less than `_MAX_NUM_ELEMENTS`.
   869          Specifically, if one considers the input as a sorted list x_1, ...,
   870          x_N, then the distance between each exact quantile x_c and its
   871          approximation x_c' is bounded by `|c - c'| < epsilon * N`. Note that
   872          these errors are worst-case scenarios. In practice the accuracy tends
   873          to be much better.
   874        max_num_elements: (optional) The cost (in time and space) to compute
   875          quantiles to a given accuracy is a function of the total number of
   876          elements in the data set.
   877        key: (optional) Key is a mapping of elements to a comparable key, similar
   878          to the key argument of Python's sorting methods.
   879        reverse: (optional) whether to order things smallest to largest, rather
   880          than largest to smallest.
   881        weighted: (optional) if set to True, the combiner produces weighted
   882          quantiles. The input elements are then expected to be tuples of values
   883          with the corresponding weight.
   884        input_batched: (optional) if set to True, inputs are expected to be
   885          batches of elements.
   886      """
   887      max_num_elements = max_num_elements or cls._MAX_NUM_ELEMENTS
   888      if not epsilon:
   889        epsilon = min(1e-2, 1.0 / num_quantiles) \
   890          if weighted else (1.0 / num_quantiles)
   891      # Note that calculation of the buffer size and the number of buffers here
   892      # is based on technique used in the Munro-Paterson algorithm. Switching to
   893      # the logic used in the "New Algorithm" may result in memory savings since
   894      # it results in lower values for b and k in practice.
   895      b = 2
   896      while (b - 2) * (1 << (b - 2)) < epsilon * max_num_elements:
   897        b = b + 1
   898      b = b - 1
   899      k = max(2, int(math.ceil(max_num_elements / float(1 << (b - 1)))))
   900      return cls(
   901          num_quantiles=num_quantiles,
   902          buffer_size=k,
   903          num_buffers=b,
   904          key=key,
   905          reverse=reverse,
   906          weighted=weighted,
   907          input_batched=input_batched)
   908  
   909    def _offset(self, new_weight):
   910      # type: (int) -> float
   911  
   912      """
   913      If the weight is even, we must round up or down. Alternate between these
   914      two options to avoid a bias.
   915      """
   916      if new_weight % 2 == 1:
   917        return (new_weight + 1) / 2
   918      else:
   919        self._offset_jitter = 2 - self._offset_jitter
   920        return (new_weight + self._offset_jitter) / 2
   921  
   922    # TODO(https://github.com/apache/beam/issues/19737): Signature incompatible
   923    # with supertype
   924    def create_accumulator(self):  # type: ignore[override]
   925      # type: () -> _QuantileState
   926      self._qs = _QuantileState(
   927          unbuffered_elements=[],
   928          unbuffered_weights=[],
   929          buffers=[],
   930          spec=self._spec)
   931      return self._qs
   932  
   933    def add_input(self, quantile_state, element):
   934      """
   935      Add a new element to the collection being summarized by quantile state.
   936      """
   937      quantile_state.add_unbuffered([element], self._offset)
   938      return quantile_state
   939  
   940    def _add_inputs(self, quantile_state, elements):
   941      # type: (_QuantileState, List) -> _QuantileState
   942  
   943      """
   944      Add a batch of elements to the collection being summarized by quantile
   945      state.
   946      """
   947      if len(elements) == 0:
   948        return quantile_state
   949      quantile_state.add_unbuffered(elements, self._offset)
   950      return quantile_state
   951  
   952    def merge_accumulators(self, accumulators):
   953      """Merges all the accumulators (quantile state) as one."""
   954      qs = self.create_accumulator()
   955      for accumulator in accumulators:
   956        if accumulator.is_empty():
   957          continue
   958        if self._spec.weighted:
   959          qs.add_unbuffered(
   960              [accumulator.unbuffered_elements, accumulator.unbuffered_weights],
   961              self._offset)
   962        else:
   963          qs.add_unbuffered(accumulator.unbuffered_elements, self._offset)
   964  
   965        qs.buffers.extend(accumulator.buffers)
   966      heapq.heapify(qs.buffers)
   967      qs.collapse_if_needed(self._offset)
   968      return qs
   969  
   970    def extract_output(self, accumulator):
   971      """
   972      Outputs num_quantiles elements consisting of the minimum, maximum and
   973      num_quantiles - 2 evenly spaced intermediate elements. Returns the empty
   974      list if no elements have been added.
   975      """
   976      if accumulator.is_empty():
   977        return []
   978      accumulator.finalize()
   979      all_elems = accumulator.buffers
   980      total_weight = 0
   981      if self._spec.weighted:
   982        for buffer_elem in all_elems:
   983          total_weight += sum(buffer_elem.weights)
   984      else:
   985        for buffer_elem in all_elems:
   986          total_weight += len(buffer_elem.elements) * buffer_elem.weights[0]
   987  
   988      step = total_weight / (self._num_quantiles - 1)
   989      offset = (total_weight - 1) / (self._num_quantiles - 1)
   990  
   991      quantiles, _, min_val, max_val = \
   992          _interpolate(all_elems, self._num_quantiles - 2, step, offset,
   993                       self._spec)
   994  
   995      return [min_val] + quantiles + [max_val]