github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/stats_test.py (about)

     1  # -*- coding: utf-8 -*-
     2  #
     3  # Licensed to the Apache Software Foundation (ASF) under one or more
     4  # contributor license agreements.  See the NOTICE file distributed with
     5  # this work for additional information regarding copyright ownership.
     6  # The ASF licenses this file to You under the Apache License, Version 2.0
     7  # (the "License"); you may not use this file except in compliance with
     8  # the License.  You may obtain a copy of the License at
     9  #
    10  #    http://www.apache.org/licenses/LICENSE-2.0
    11  #
    12  # Unless required by applicable law or agreed to in writing, software
    13  # distributed under the License is distributed on an "AS IS" BASIS,
    14  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  # See the License for the specific language governing permissions and
    16  # limitations under the License.
    17  #
    18  
    19  # pytype: skip-file
    20  
    21  import math
    22  import random
    23  import sys
    24  import unittest
    25  from collections import defaultdict
    26  
    27  import hamcrest as hc
    28  from parameterized import parameterized
    29  from parameterized import parameterized_class
    30  
    31  import apache_beam as beam
    32  from apache_beam.coders import coders
    33  from apache_beam.testing.test_pipeline import TestPipeline
    34  from apache_beam.testing.util import BeamAssertException
    35  from apache_beam.testing.util import assert_that
    36  from apache_beam.testing.util import equal_to
    37  from apache_beam.transforms.core import Create
    38  from apache_beam.transforms.display import DisplayData
    39  from apache_beam.transforms.display_test import DisplayDataItemMatcher
    40  from apache_beam.transforms.stats import ApproximateQuantilesCombineFn
    41  from apache_beam.transforms.stats import ApproximateUniqueCombineFn
    42  
    43  try:
    44    import mmh3
    45    mmh3_options = [(mmh3, ), (None, )]
    46  except ImportError:
    47    mmh3_options = [(None, )]
    48  
    49  
    50  @parameterized_class(('mmh3_option', ), mmh3_options)
    51  class ApproximateUniqueTest(unittest.TestCase):
    52    """Unit tests for ApproximateUnique.Globally and ApproximateUnique.PerKey."""
    53    random.seed(0)
    54  
    55    def setUp(self):
    56      sys.modules['mmh3'] = self.mmh3_option
    57  
    58    @parameterized.expand([
    59        (
    60            'small_population_by_size',
    61            list(range(30)),
    62            32,
    63            None,
    64            'assert:global_by_sample_size_with_small_population'),
    65        (
    66            'large_population_by_size',
    67            list(range(100)),
    68            16,
    69            None,
    70            'assert:global_by_sample_size_with_large_population'),
    71        (
    72            'with_duplicates_by_size', [10] * 50 + [20] * 50,
    73            30,
    74            None,
    75            'assert:global_by_sample_size_with_duplicates'),
    76        (
    77            'small_population_by_error',
    78            list(range(30)),
    79            None,
    80            0.3,
    81            'assert:global_by_error_with_small_population'),
    82        (
    83            'large_population_by_error',
    84            [random.randint(1, 1000) for _ in range(500)],
    85            None,
    86            0.1,
    87            'assert:global_by_error_with_large_population'),
    88    ])
    89    def test_approximate_unique_global(
    90        self, name, test_input, sample_size, est_error, label):
    91      # check that only either sample_size or est_error is not None
    92      assert bool(sample_size) != bool(est_error)
    93      if sample_size:
    94        error = 2 / math.sqrt(sample_size)
    95      else:
    96        error = est_error
    97      random.shuffle(test_input)
    98      actual_count = len(set(test_input))
    99  
   100      with TestPipeline() as pipeline:
   101        result = (
   102            pipeline
   103            | 'create' >> beam.Create(test_input)
   104            | 'get_estimate' >> beam.ApproximateUnique.Globally(
   105                size=sample_size, error=est_error)
   106            | 'compare' >> beam.FlatMap(
   107                lambda x: [abs(x - actual_count) * 1.0 / actual_count <= error]))
   108  
   109        assert_that(result, equal_to([True]), label=label)
   110  
   111    @parameterized.expand([
   112        ('by_size', 20, None, 'assert:unique_perkey_by_sample_size'),
   113        ('by_error', None, 0.02, 'assert:unique_perkey_by_error')
   114    ])
   115    def test_approximate_unique_perkey(self, name, sample_size, est_error, label):
   116      # check that only either sample_size or est_error is set
   117      assert bool(sample_size) != bool(est_error)
   118      if sample_size:
   119        error = 2 / math.sqrt(sample_size)
   120      else:
   121        error = est_error
   122  
   123      test_input = [(8, 73), (6, 724), (7, 70), (1, 576), (10, 120), (2, 662),
   124                    (7, 115), (3, 731), (6, 340), (6, 623), (1, 74), (9, 280),
   125                    (8, 298), (6, 440), (10, 243), (1, 125), (9, 754), (8, 833),
   126                    (9, 751), (4, 818), (6, 176), (9, 253), (2, 721), (8, 936),
   127                    (3, 691), (10, 685), (1, 69), (3, 155), (8, 86), (5, 693),
   128                    (2, 809), (4, 723), (8, 102), (9, 707), (8, 558), (4, 537),
   129                    (5, 371), (7, 432), (2, 51), (10, 397)]
   130      actual_count_dict = defaultdict(set)
   131      for (x, y) in test_input:
   132        actual_count_dict[x].add(y)
   133  
   134      with TestPipeline() as pipeline:
   135        result = (
   136            pipeline
   137            | 'create' >> beam.Create(test_input)
   138            | 'get_estimate' >> beam.ApproximateUnique.PerKey(
   139                size=sample_size, error=est_error)
   140            | 'compare' >> beam.FlatMap(
   141                lambda x: [
   142                    abs(x[1] - len(actual_count_dict[x[0]])) * 1.0 / len(
   143                        actual_count_dict[x[0]]) <= error
   144                ]))
   145  
   146        assert_that(
   147            result, equal_to([True] * len(actual_count_dict)), label=label)
   148  
   149    @parameterized.expand([
   150        (
   151            'invalid_input_size',
   152            list(range(30)),
   153            10,
   154            None,
   155            beam.ApproximateUnique._INPUT_SIZE_ERR_MSG % 10),
   156        (
   157            'invalid_type_size',
   158            list(range(30)),
   159            100.0,
   160            None,
   161            beam.ApproximateUnique._INPUT_SIZE_ERR_MSG % 100.0),
   162        (
   163            'invalid_small_error',
   164            list(range(30)),
   165            None,
   166            0.0,
   167            beam.ApproximateUnique._INPUT_ERROR_ERR_MSG % 0.0),
   168        (
   169            'invalid_big_error',
   170            list(range(30)),
   171            None,
   172            0.6,
   173            beam.ApproximateUnique._INPUT_ERROR_ERR_MSG % 0.6),
   174        (
   175            'no_input',
   176            list(range(30)),
   177            None,
   178            None,
   179            beam.ApproximateUnique._NO_VALUE_ERR_MSG),
   180        (
   181            'both_input',
   182            list(range(30)),
   183            30,
   184            0.2,
   185            beam.ApproximateUnique._MULTI_VALUE_ERR_MSG % (30, 0.2)),
   186    ])
   187    def test_approximate_unique_global_value_error(
   188        self, name, test_input, sample_size, est_error, expected_msg):
   189      with self.assertRaises(ValueError) as e:
   190        with TestPipeline() as pipeline:
   191          _ = (
   192              pipeline
   193              | 'create' >> beam.Create(test_input)
   194              | 'get_estimate' >> beam.ApproximateUnique.Globally(
   195                  size=sample_size, error=est_error))
   196  
   197      assert e.exception.args[0] == expected_msg
   198  
   199    def test_approximate_unique_combine_fn_requires_nondeterministic_coder(self):
   200      sample_size = 30
   201      coder = coders.Base64PickleCoder()
   202  
   203      with self.assertRaises(ValueError) as e:
   204        _ = ApproximateUniqueCombineFn(sample_size, coder)
   205  
   206      self.assertRegex(
   207          e.exception.args[0],
   208          'The key coder "Base64PickleCoder" '
   209          'for ApproximateUniqueCombineFn is not deterministic.')
   210  
   211    def test_approximate_unique_combine_fn_requires_compatible_coder(self):
   212      test_input = 'a'
   213      sample_size = 30
   214      coder = coders.FloatCoder()
   215      combine_fn = ApproximateUniqueCombineFn(sample_size, coder)
   216      accumulator = combine_fn.create_accumulator()
   217      with self.assertRaises(RuntimeError) as e:
   218        accumulator = combine_fn.add_input(accumulator, test_input)
   219  
   220      self.assertRegex(e.exception.args[0], 'Runtime exception')
   221  
   222    def test_get_sample_size_from_est_error(self):
   223      # test if get correct sample size from input error.
   224      assert beam.ApproximateUnique._get_sample_size_from_est_error(0.5) == 16
   225      assert beam.ApproximateUnique._get_sample_size_from_est_error(0.4) == 25
   226      assert beam.ApproximateUnique._get_sample_size_from_est_error(0.2) == 100
   227      assert beam.ApproximateUnique._get_sample_size_from_est_error(0.1) == 400
   228      assert beam.ApproximateUnique._get_sample_size_from_est_error(0.05) == 1600
   229      assert beam.ApproximateUnique._get_sample_size_from_est_error(0.01) == 40000
   230  
   231  
   232  class ApproximateQuantilesTest(unittest.TestCase):
   233    _kv_data = [("a", 1), ("a", 2), ("a", 3), ("b", 1), ("b", 10), ("b", 10),
   234                ("b", 100)]
   235    _kv_str_data = [("a", "a"), ("a", "a" * 2), ("a", "a" * 3), ("b", "b"),
   236                    ("b", "b" * 10), ("b", "b" * 10), ("b", "b" * 100)]
   237  
   238    @staticmethod
   239    def _quantiles_matcher(expected):
   240      l = len(expected)
   241  
   242      def assert_true(exp):
   243        if not exp:
   244          raise BeamAssertException('%s Failed assert True' % repr(exp))
   245  
   246      def match(actual):
   247        actual = actual[0]
   248        for i in range(l):
   249          if isinstance(expected[i], list):
   250            assert_true(expected[i][0] <= actual[i] <= expected[i][1])
   251          else:
   252            equal_to([expected[i]])([actual[i]])
   253  
   254      return match
   255  
   256    @staticmethod
   257    def _approx_quantile_generator(size, num_of_quantiles, absoluteError):
   258      quantiles = [0]
   259      k = 1
   260      while k < num_of_quantiles - 1:
   261        expected = (size - 1) * k / (num_of_quantiles - 1)
   262        quantiles.append([expected - absoluteError, expected + absoluteError])
   263        k = k + 1
   264      quantiles.append(size - 1)
   265      return quantiles
   266  
   267    def test_quantiles_globaly(self):
   268      with TestPipeline() as p:
   269        pc = p | Create(list(range(101)))
   270  
   271        quantiles = pc | 'Quantiles globally' >> \
   272                    beam.ApproximateQuantiles.Globally(5)
   273        quantiles_reversed = pc | 'Quantiles globally reversed' >> \
   274                             beam.ApproximateQuantiles.Globally(5, reverse=True)
   275  
   276        assert_that(
   277            quantiles,
   278            equal_to([[0, 25, 50, 75, 100]]),
   279            label='checkQuantilesGlobally')
   280        assert_that(
   281            quantiles_reversed,
   282            equal_to([[100, 75, 50, 25, 0]]),
   283            label='checkReversedQuantiles')
   284  
   285    def test_quantiles_globally_weighted(self):
   286      num_inputs = 1e3
   287      a = -3
   288      b = 3
   289  
   290      # Weighting function coincides with the pdf of the standard normal
   291      # distribution up to a constant. Since 99.7% of the probability mass for
   292      # this pdf is concentrated in the interval [a, b] = [-3, 3], the quantiles
   293      # for a sample from this interval with the given weight function are
   294      # expected to be close to the quantiles of the standard normal distribution.
   295      def weight(x):
   296        return math.exp(-(x**2) / 2)
   297  
   298      input_data = [
   299          (a + (b - a) * i / num_inputs, weight(a + (b - a) * i / num_inputs))
   300          for i in range(int(num_inputs) + 1)
   301      ]
   302      with TestPipeline() as p:
   303        pc = p | Create(input_data)
   304  
   305        weighted_quantiles = pc | "Quantiles globally weighted" >> \
   306                             beam.ApproximateQuantiles.Globally(5, weighted=True)
   307        reversed_weighted_quantiles = (
   308            pc | 'Quantiles globally weighted reversed' >>
   309            beam.ApproximateQuantiles.Globally(5, reverse=True, weighted=True))
   310  
   311        assert_that(
   312            weighted_quantiles,
   313            equal_to([[-3., -0.6720000000000002, 0., 0.6720000000000002, 3.]]),
   314            label="checkWeightedQuantilesGlobally")
   315        assert_that(
   316            reversed_weighted_quantiles,
   317            equal_to([[3., 0.6720000000000002, 0., -0.6720000000000002, -3.]]),
   318            label="checkWeightedReversedQuantilesGlobally")
   319  
   320    def test_quantiles_per_key(self):
   321      with TestPipeline() as p:
   322        data = self._kv_data
   323        pc = p | Create(data)
   324  
   325        per_key = pc | 'Quantiles PerKey' >> beam.ApproximateQuantiles.PerKey(2)
   326        per_key_reversed = (
   327            pc | 'Quantiles PerKey Reversed' >> beam.ApproximateQuantiles.PerKey(
   328                2, reverse=True))
   329  
   330        assert_that(
   331            per_key,
   332            equal_to([('a', [1, 3]), ('b', [1, 100])]),
   333            label='checkQuantilePerKey')
   334        assert_that(
   335            per_key_reversed,
   336            equal_to([('a', [3, 1]), ('b', [100, 1])]),
   337            label='checkReversedQuantilesPerKey')
   338  
   339    def test_quantiles_per_key_weighted(self):
   340      with TestPipeline() as p:
   341        data = [(k, (v, 2.)) for k, v in self._kv_data]
   342        pc = p | Create(data)
   343  
   344        per_key = pc | 'Weighted Quantiles PerKey' >> \
   345                  beam.ApproximateQuantiles.PerKey(2, weighted=True)
   346        per_key_reversed = pc | 'Weighted Quantiles PerKey Reversed' >> \
   347                           beam.ApproximateQuantiles.PerKey(
   348                             2, reverse=True, weighted=True)
   349  
   350        assert_that(
   351            per_key,
   352            equal_to([('a', [1, 3]), ('b', [1, 100])]),
   353            label='checkWeightedQuantilesPerKey')
   354        assert_that(
   355            per_key_reversed,
   356            equal_to([('a', [3, 1]), ('b', [100, 1])]),
   357            label='checkWeightedReversedQuantilesPerKey')
   358  
   359    def test_quantiles_per_key_with_key_argument(self):
   360      with TestPipeline() as p:
   361        data = self._kv_str_data
   362        pc = p | Create(data)
   363  
   364        per_key = pc | 'Per Key' >> beam.ApproximateQuantiles.PerKey(2, key=len)
   365        per_key_reversed = (
   366            pc | 'Per Key Reversed' >> beam.ApproximateQuantiles.PerKey(
   367                2, key=len, reverse=True))
   368  
   369        assert_that(
   370            per_key,
   371            equal_to([('a', ['a', 'a' * 3]), ('b', ['b', 'b' * 100])]),
   372            label='checkPerKey')
   373        assert_that(
   374            per_key_reversed,
   375            equal_to([('a', ['a' * 3, 'a']), ('b', ['b' * 100, 'b'])]),
   376            label='checkPerKeyReversed')
   377  
   378    def test_singleton(self):
   379      with TestPipeline() as p:
   380        data = [389]
   381        pc = p | Create(data)
   382        quantiles = pc | beam.ApproximateQuantiles.Globally(5)
   383        assert_that(quantiles, equal_to([[389, 389, 389, 389, 389]]))
   384  
   385    def test_uneven_quantiles(self):
   386      with TestPipeline() as p:
   387        pc = p | Create(list(range(5000)))
   388        quantiles = pc | beam.ApproximateQuantiles.Globally(37)
   389        approx_quantiles = self._approx_quantile_generator(
   390            size=5000, num_of_quantiles=37, absoluteError=20)
   391        assert_that(quantiles, self._quantiles_matcher(approx_quantiles))
   392  
   393    def test_large_quantiles(self):
   394      with TestPipeline() as p:
   395        pc = p | Create(list(range(10001)))
   396        quantiles = pc | beam.ApproximateQuantiles.Globally(50)
   397        approx_quantiles = self._approx_quantile_generator(
   398            size=10001, num_of_quantiles=50, absoluteError=20)
   399        assert_that(quantiles, self._quantiles_matcher(approx_quantiles))
   400  
   401    def test_random_quantiles(self):
   402      with TestPipeline() as p:
   403        data = list(range(101))
   404        random.shuffle(data)
   405        pc = p | Create(data)
   406        quantiles = pc | beam.ApproximateQuantiles.Globally(5)
   407        assert_that(quantiles, equal_to([[0, 25, 50, 75, 100]]))
   408  
   409    def test_duplicates(self):
   410      y = list(range(101))
   411      data = []
   412      for _ in range(10):
   413        data.extend(y)
   414  
   415      with TestPipeline() as p:
   416        pc = p | Create(data)
   417        quantiles = (
   418            pc | 'Quantiles Globally' >> beam.ApproximateQuantiles.Globally(5))
   419        quantiles_reversed = (
   420            pc | 'Quantiles Reversed' >> beam.ApproximateQuantiles.Globally(
   421                5, reverse=True))
   422  
   423        assert_that(
   424            quantiles,
   425            equal_to([[0, 25, 50, 75, 100]]),
   426            label="checkQuantilesGlobally")
   427        assert_that(
   428            quantiles_reversed,
   429            equal_to([[100, 75, 50, 25, 0]]),
   430            label="checkQuantileReversed")
   431  
   432    def test_lots_of_duplicates(self):
   433      with TestPipeline() as p:
   434        data = [1]
   435        data.extend([2 for _ in range(299)])
   436        data.extend([3 for _ in range(799)])
   437        pc = p | Create(data)
   438        quantiles = pc | beam.ApproximateQuantiles.Globally(5)
   439        assert_that(quantiles, equal_to([[1, 2, 3, 3, 3]]))
   440  
   441    def test_log_distribution(self):
   442      with TestPipeline() as p:
   443        data = [int(math.log(x)) for x in range(1, 1000)]
   444        pc = p | Create(data)
   445        quantiles = pc | beam.ApproximateQuantiles.Globally(5)
   446        assert_that(quantiles, equal_to([[0, 5, 6, 6, 6]]))
   447  
   448    def test_zipfian_distribution(self):
   449      with TestPipeline() as p:
   450        data = []
   451        for i in range(1, 1000):
   452          data.append(int(1000 / i))
   453        pc = p | Create(data)
   454        quantiles = pc | beam.ApproximateQuantiles.Globally(5)
   455        assert_that(quantiles, equal_to([[1, 1, 2, 4, 1000]]))
   456  
   457    def test_alternate_quantiles(self):
   458      data = ["aa", "aaa", "aaaa", "b", "ccccc", "dddd", "zz"]
   459      with TestPipeline() as p:
   460        pc = p | Create(data)
   461  
   462        globally = pc | 'Globally' >> beam.ApproximateQuantiles.Globally(3)
   463        with_key = (
   464            pc |
   465            'Globally with key' >> beam.ApproximateQuantiles.Globally(3, key=len))
   466        key_with_reversed = (
   467            pc | 'Globally with key and reversed' >>
   468            beam.ApproximateQuantiles.Globally(3, key=len, reverse=True))
   469  
   470        assert_that(
   471            globally, equal_to([["aa", "b", "zz"]]), label='checkGlobally')
   472        assert_that(
   473            with_key,
   474            equal_to([["b", "aaa", "ccccc"]]),
   475            label='checkGloballyWithKey')
   476        assert_that(
   477            key_with_reversed,
   478            equal_to([["ccccc", "aaa", "b"]]),
   479            label='checkWithKeyAndReversed')
   480  
   481    def test_batched_quantiles(self):
   482      with TestPipeline() as p:
   483        data = []
   484        for i in range(100):
   485          data.append([(j / 10, abs(j - 500))
   486                       for j in range(i * 10, (i + 1) * 10)])
   487        pc = p | Create(data)
   488        globally = (
   489            pc | 'Globally' >> beam.ApproximateQuantiles.Globally(
   490                3, input_batched=True))
   491        with_key = (
   492            pc | 'Globally with key' >> beam.ApproximateQuantiles.Globally(
   493                3, key=sum, input_batched=True))
   494        key_with_reversed = (
   495            pc | 'Globally with key and reversed' >>
   496            beam.ApproximateQuantiles.Globally(
   497                3, key=sum, reverse=True, input_batched=True))
   498        assert_that(
   499            globally,
   500            equal_to([[(0.0, 500), (49.9, 1), (99.9, 499)]]),
   501            label='checkGlobally')
   502        assert_that(
   503            with_key,
   504            equal_to([[(50.0, 0), (72.5, 225), (99.9, 499)]]),
   505            label='checkGloballyWithKey')
   506        assert_that(
   507            key_with_reversed,
   508            equal_to([[(99.9, 499), (72.5, 225), (50.0, 0)]]),
   509            label='checkGloballyWithKeyAndReversed')
   510  
   511    def test_batched_weighted_quantiles(self):
   512      with TestPipeline() as p:
   513        data = []
   514        for i in range(100):
   515          data.append([[(i / 10, abs(i - 500))
   516                        for i in range(i * 10, (i + 1) * 10)], [i] * 10])
   517        pc = p | Create(data)
   518        globally = (
   519            pc | 'Globally' >> beam.ApproximateQuantiles.Globally(
   520                3, weighted=True, input_batched=True))
   521        with_key = (
   522            pc | 'Globally with key' >> beam.ApproximateQuantiles.Globally(
   523                3, key=sum, weighted=True, input_batched=True))
   524        key_with_reversed = (
   525            pc | 'Globally with key and reversed' >>
   526            beam.ApproximateQuantiles.Globally(
   527                3, key=sum, reverse=True, weighted=True, input_batched=True))
   528        assert_that(
   529            globally,
   530            equal_to([[(0.0, 500), (70.8, 208), (99.9, 499)]]),
   531            label='checkGlobally')
   532        assert_that(
   533            with_key,
   534            equal_to([[(50.0, 0), (21.0, 290), (99.9, 499)]]),
   535            label='checkGloballyWithKey')
   536        assert_that(
   537            key_with_reversed,
   538            equal_to([[(99.9, 499), (21.0, 290), (50.0, 0)]]),
   539            label='checkGloballyWithKeyAndReversed')
   540  
   541    def test_quantiles_merge_accumulators(self):
   542      # This test exercises merging multiple buffers and approximation accuracy.
   543      # The max_num_elements is set to a small value to trigger buffers collapse
   544      # and interpolation. Under the conditions below, buffer_size=125 and
   545      # num_buffers=4, so we're only allowed to keep half of the input values.
   546      num_accumulators = 100
   547      num_quantiles = 5
   548      eps = 0.01
   549      max_num_elements = 1000
   550      combine_fn = ApproximateQuantilesCombineFn.create(
   551          num_quantiles, eps, max_num_elements)
   552      combine_fn_weighted = ApproximateQuantilesCombineFn.create(
   553          num_quantiles, eps, max_num_elements, weighted=True)
   554      data = list(range(1000))
   555      weights = list(reversed(range(1000)))
   556      step = math.ceil(len(data) / num_accumulators)
   557      accumulators = []
   558      accumulators_weighted = []
   559      for i in range(num_accumulators):
   560        accumulator = combine_fn.create_accumulator()
   561        accumulator_weighted = combine_fn_weighted.create_accumulator()
   562        for element, weight in zip(data[i*step:(i+1)*step],
   563                                   weights[i*step:(i+1)*step]):
   564          accumulator = combine_fn.add_input(accumulator, element)
   565          accumulator_weighted = combine_fn_weighted.add_input(
   566              accumulator_weighted, (element, weight))
   567        accumulators.append(accumulator)
   568        accumulators_weighted.append(accumulator_weighted)
   569      accumulator = combine_fn.merge_accumulators(accumulators)
   570      accumulator_weighted = combine_fn_weighted.merge_accumulators(
   571          accumulators_weighted)
   572      quantiles = combine_fn.extract_output(accumulator)
   573      quantiles_weighted = combine_fn_weighted.extract_output(
   574          accumulator_weighted)
   575  
   576      # In fact, the final accuracy is much higher than eps, but we test for a
   577      # minimal accuracy here.
   578      for q, actual_q in zip(quantiles, [0, 249, 499, 749, 999]):
   579        self.assertAlmostEqual(q, actual_q, delta=max_num_elements * eps)
   580      for q, actual_q in zip(quantiles_weighted, [0, 133, 292, 499, 999]):
   581        self.assertAlmostEqual(q, actual_q, delta=max_num_elements * eps)
   582  
   583    @staticmethod
   584    def _display_data_matcher(instance):
   585      expected_items = [
   586          DisplayDataItemMatcher('num_quantiles', instance._num_quantiles),
   587          DisplayDataItemMatcher('weighted', str(instance._weighted)),
   588          DisplayDataItemMatcher('key', str(instance._key.__name__)),
   589          DisplayDataItemMatcher('reverse', str(instance._reverse)),
   590          DisplayDataItemMatcher('input_batched', str(instance._input_batched)),
   591      ]
   592      return expected_items
   593  
   594    def test_global_display_data(self):
   595      transform = beam.ApproximateQuantiles.Globally(
   596          3, weighted=True, key=len, reverse=True)
   597      data = DisplayData.create_from(transform)
   598      expected_items = self._display_data_matcher(transform)
   599      hc.assert_that(data.items, hc.contains_inanyorder(*expected_items))
   600  
   601    def test_perkey_display_data(self):
   602      transform = beam.ApproximateQuantiles.PerKey(
   603          3, weighted=True, key=len, reverse=True)
   604      data = DisplayData.create_from(transform)
   605      expected_items = self._display_data_matcher(transform)
   606      hc.assert_that(data.items, hc.contains_inanyorder(*expected_items))
   607  
   608  
   609  def _build_quantilebuffer_test_data():
   610    """
   611    Test data taken from "Munro-Paterson Algorithm" reference values table of
   612    "Approximate Medians and other Quantiles in One Pass and with Limited Memory"
   613    paper. See ApproximateQuantilesCombineFn for paper reference.
   614    """
   615    epsilons = [0.1, 0.05, 0.01, 0.005, 0.001]
   616    maxElementExponents = [5, 6, 7, 8, 9]
   617    expectedNumBuffersValues = [[11, 14, 17, 21, 24], [11, 14, 17, 20, 23],
   618                                [9, 11, 14, 17, 21], [8, 11, 14, 17,
   619                                                      20], [6, 9, 11, 14, 17]]
   620    expectedBufferSizeValues = [[98, 123, 153, 96, 120], [98, 123, 153, 191, 239],
   621                                [391, 977, 1221, 1526,
   622                                 954], [782, 977, 1221, 1526,
   623                                        1908], [3125, 3907, 9766, 12208, 15259]]
   624    test_data = []
   625    i = 0
   626    for epsilon in epsilons:
   627      j = 0
   628      for maxElementExponent in maxElementExponents:
   629        test_data.append([
   630            epsilon, (10**maxElementExponent),
   631            expectedNumBuffersValues[i][j],
   632            expectedBufferSizeValues[i][j]
   633        ])
   634        j += 1
   635      i += 1
   636    return test_data
   637  
   638  
   639  class ApproximateQuantilesBufferTest(unittest.TestCase):
   640    """ Approximate Quantiles Buffer Tests to ensure we are calculating the
   641    optimal buffers."""
   642    @parameterized.expand(_build_quantilebuffer_test_data)
   643    def test_efficiency(
   644        self, epsilon, maxInputSize, expectedNumBuffers, expectedBufferSize):
   645      """
   646      Verify the buffers are efficiently calculated according to the reference
   647      table values.
   648      """
   649  
   650      combine_fn = ApproximateQuantilesCombineFn.create(
   651          num_quantiles=10, max_num_elements=maxInputSize, epsilon=epsilon)
   652      self.assertEqual(
   653          expectedNumBuffers, combine_fn._spec.num_buffers, "Number of buffers")
   654      self.assertEqual(
   655          expectedBufferSize, combine_fn._spec.buffer_size, "Buffer size")
   656  
   657    @parameterized.expand(_build_quantilebuffer_test_data)
   658    def test_correctness(self, epsilon, maxInputSize, *args):
   659      """
   660      Verify that buffers are correct according to the two constraint equations.
   661      """
   662      combine_fn = ApproximateQuantilesCombineFn.create(
   663          num_quantiles=10, max_num_elements=maxInputSize, epsilon=epsilon)
   664      b = combine_fn._spec.num_buffers
   665      k = combine_fn._spec.buffer_size
   666      n = maxInputSize
   667      self.assertLessEqual((b - 2) * (1 << (b - 2)) + 0.5, (epsilon * n),
   668                           '(b-2)2^(b-2) + 1/2 <= eN')
   669      self.assertGreaterEqual((k * 2)**(b - 1), n, 'k2^(b-1) >= N')
   670  
   671  
   672  if __name__ == '__main__':
   673    unittest.main()