github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/combiners_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for our libraries of combine PTransforms."""
    19  # pytype: skip-file
    20  
    21  import itertools
    22  import random
    23  import unittest
    24  
    25  import hamcrest as hc
    26  import pytest
    27  
    28  import apache_beam as beam
    29  import apache_beam.transforms.combiners as combine
    30  from apache_beam.metrics import Metrics
    31  from apache_beam.metrics import MetricsFilter
    32  from apache_beam.options.pipeline_options import PipelineOptions
    33  from apache_beam.options.pipeline_options import StandardOptions
    34  from apache_beam.testing.test_pipeline import TestPipeline
    35  from apache_beam.testing.test_stream import TestStream
    36  from apache_beam.testing.util import assert_that
    37  from apache_beam.testing.util import equal_to
    38  from apache_beam.testing.util import equal_to_per_window
    39  from apache_beam.transforms import WindowInto
    40  from apache_beam.transforms import trigger
    41  from apache_beam.transforms import window
    42  from apache_beam.transforms.core import CombineGlobally
    43  from apache_beam.transforms.core import Create
    44  from apache_beam.transforms.core import Map
    45  from apache_beam.transforms.display import DisplayData
    46  from apache_beam.transforms.display_test import DisplayDataItemMatcher
    47  from apache_beam.transforms.ptransform import PTransform
    48  from apache_beam.transforms.trigger import AfterAll
    49  from apache_beam.transforms.trigger import AfterCount
    50  from apache_beam.transforms.trigger import AfterWatermark
    51  from apache_beam.transforms.window import FixedWindows
    52  from apache_beam.transforms.window import GlobalWindows
    53  from apache_beam.transforms.window import TimestampCombiner
    54  from apache_beam.transforms.window import TimestampedValue
    55  from apache_beam.typehints import TypeCheckError
    56  from apache_beam.utils.timestamp import Timestamp
    57  
    58  
    59  class SortedConcatWithCounters(beam.CombineFn):
    60    """CombineFn for incrementing three different counters:
    61       counter, distribution, gauge,
    62       at the same time concatenating words."""
    63    def __init__(self):
    64      beam.CombineFn.__init__(self)
    65      self.word_counter = Metrics.counter(self.__class__, 'word_counter')
    66      self.word_lengths_counter = Metrics.counter(self.__class__, 'word_lengths')
    67      self.word_lengths_dist = Metrics.distribution(
    68          self.__class__, 'word_len_dist')
    69      self.last_word_len = Metrics.gauge(self.__class__, 'last_word_len')
    70  
    71    def create_accumulator(self):
    72      return ''
    73  
    74    def add_input(self, acc, element):
    75      self.word_counter.inc(1)
    76      self.word_lengths_counter.inc(len(element))
    77      self.word_lengths_dist.update(len(element))
    78      self.last_word_len.set(len(element))
    79  
    80      return acc + element
    81  
    82    def merge_accumulators(self, accs):
    83      return ''.join(accs)
    84  
    85    def extract_output(self, acc):
    86      # The sorted acc became a list of characters
    87      # and has to be converted back to a string using join.
    88      return ''.join(sorted(acc))
    89  
    90  
    91  class CombineTest(unittest.TestCase):
    92    def test_builtin_combines(self):
    93      with TestPipeline() as pipeline:
    94  
    95        vals = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
    96        mean = sum(vals) / float(len(vals))
    97        size = len(vals)
    98        timestamp = 0
    99  
   100        # First for global combines.
   101        pcoll = pipeline | 'start' >> Create(vals)
   102        result_mean = pcoll | 'mean' >> combine.Mean.Globally()
   103        result_count = pcoll | 'count' >> combine.Count.Globally()
   104        assert_that(result_mean, equal_to([mean]), label='assert:mean')
   105        assert_that(result_count, equal_to([size]), label='assert:size')
   106  
   107        # Now for global combines without default
   108        timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
   109        windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
   110        result_windowed_mean = (
   111            windowed
   112            | 'mean-wo-defaults' >> combine.Mean.Globally().without_defaults())
   113        assert_that(
   114            result_windowed_mean,
   115            equal_to([mean]),
   116            label='assert:mean-wo-defaults')
   117        result_windowed_count = (
   118            windowed
   119            | 'count-wo-defaults' >> combine.Count.Globally().without_defaults())
   120        assert_that(
   121            result_windowed_count,
   122            equal_to([size]),
   123            label='assert:count-wo-defaults')
   124  
   125        # Again for per-key combines.
   126        pcoll = pipeline | 'start-perkey' >> Create([('a', x) for x in vals])
   127        result_key_mean = pcoll | 'mean-perkey' >> combine.Mean.PerKey()
   128        result_key_count = pcoll | 'count-perkey' >> combine.Count.PerKey()
   129        assert_that(result_key_mean, equal_to([('a', mean)]), label='key:mean')
   130        assert_that(result_key_count, equal_to([('a', size)]), label='key:size')
   131  
   132    def test_top(self):
   133      with TestPipeline() as pipeline:
   134        timestamp = 0
   135  
   136        # First for global combines.
   137        pcoll = pipeline | 'start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
   138        result_top = pcoll | 'top' >> combine.Top.Largest(5)
   139        result_bot = pcoll | 'bot' >> combine.Top.Smallest(4)
   140        assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='assert:top')
   141        assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='assert:bot')
   142  
   143        # Now for global combines without default
   144        timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
   145        windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
   146        result_windowed_top = windowed | 'top-wo-defaults' >> combine.Top.Largest(
   147            5, has_defaults=False)
   148        result_windowed_bot = (
   149            windowed
   150            | 'bot-wo-defaults' >> combine.Top.Smallest(4, has_defaults=False))
   151        assert_that(
   152            result_windowed_top,
   153            equal_to([[9, 6, 6, 5, 3]]),
   154            label='assert:top-wo-defaults')
   155        assert_that(
   156            result_windowed_bot,
   157            equal_to([[0, 1, 1, 1]]),
   158            label='assert:bot-wo-defaults')
   159  
   160        # Again for per-key combines.
   161        pcoll = pipeline | 'start-perkey' >> Create(
   162            [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
   163        result_key_top = pcoll | 'top-perkey' >> combine.Top.LargestPerKey(5)
   164        result_key_bot = pcoll | 'bot-perkey' >> combine.Top.SmallestPerKey(4)
   165        assert_that(
   166            result_key_top, equal_to([('a', [9, 6, 6, 5, 3])]), label='key:top')
   167        assert_that(
   168            result_key_bot, equal_to([('a', [0, 1, 1, 1])]), label='key:bot')
   169  
   170    def test_empty_global_top(self):
   171      with TestPipeline() as p:
   172        assert_that(p | beam.Create([]) | combine.Top.Largest(10), equal_to([[]]))
   173  
   174    def test_sharded_top(self):
   175      elements = list(range(100))
   176      random.shuffle(elements)
   177  
   178      with TestPipeline() as pipeline:
   179        shards = [
   180            pipeline | 'Shard%s' % shard >> beam.Create(elements[shard::7])
   181            for shard in range(7)
   182        ]
   183        assert_that(
   184            shards | beam.Flatten() | combine.Top.Largest(10),
   185            equal_to([[99, 98, 97, 96, 95, 94, 93, 92, 91, 90]]))
   186  
   187    def test_top_key(self):
   188      self.assertEqual(['aa', 'bbb', 'c', 'dddd'] | combine.Top.Of(3, key=len),
   189                       [['dddd', 'bbb', 'aa']])
   190      self.assertEqual(['aa', 'bbb', 'c', 'dddd']
   191                       | combine.Top.Of(3, key=len, reverse=True),
   192                       [['c', 'aa', 'bbb']])
   193  
   194      self.assertEqual(['xc', 'zb', 'yd', 'wa']
   195                       | combine.Top.Largest(3, key=lambda x: x[-1]),
   196                       [['yd', 'xc', 'zb']])
   197      self.assertEqual(['xc', 'zb', 'yd', 'wa']
   198                       | combine.Top.Smallest(3, key=lambda x: x[-1]),
   199                       [['wa', 'zb', 'xc']])
   200  
   201      self.assertEqual([('a', x) for x in [1, 2, 3, 4, 1, 1]]
   202                       | combine.Top.LargestPerKey(3, key=lambda x: -x),
   203                       [('a', [1, 1, 1])])
   204      self.assertEqual([('a', x) for x in [1, 2, 3, 4, 1, 1]]
   205                       | combine.Top.SmallestPerKey(3, key=lambda x: -x),
   206                       [('a', [4, 3, 2])])
   207  
   208    def test_sharded_top_combine_fn(self):
   209      def test_combine_fn(combine_fn, shards, expected):
   210        accumulators = [
   211            combine_fn.add_inputs(combine_fn.create_accumulator(), shard)
   212            for shard in shards
   213        ]
   214        final_accumulator = combine_fn.merge_accumulators(accumulators)
   215        self.assertEqual(combine_fn.extract_output(final_accumulator), expected)
   216  
   217      test_combine_fn(combine.TopCombineFn(3), [range(10), range(10)], [9, 9, 8])
   218      test_combine_fn(
   219          combine.TopCombineFn(5), [range(1000), range(100), range(1001)],
   220          [1000, 999, 999, 998, 998])
   221  
   222    def test_combine_per_key_top_display_data(self):
   223      def individual_test_per_key_dd(combineFn):
   224        transform = beam.CombinePerKey(combineFn)
   225        dd = DisplayData.create_from(transform)
   226        expected_items = [
   227            DisplayDataItemMatcher('combine_fn', combineFn.__class__),
   228            DisplayDataItemMatcher('n', combineFn._n),
   229            DisplayDataItemMatcher('compare', combineFn._compare.__name__)
   230        ]
   231        hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   232  
   233      individual_test_per_key_dd(combine.Largest(5))
   234      individual_test_per_key_dd(combine.Smallest(3))
   235      individual_test_per_key_dd(combine.TopCombineFn(8))
   236      individual_test_per_key_dd(combine.Largest(5))
   237  
   238    def test_combine_sample_display_data(self):
   239      def individual_test_per_key_dd(sampleFn, n):
   240        trs = [sampleFn(n)]
   241        for transform in trs:
   242          dd = DisplayData.create_from(transform)
   243          hc.assert_that(
   244              dd.items,
   245              hc.contains_inanyorder(DisplayDataItemMatcher('n', transform._n)))
   246  
   247      individual_test_per_key_dd(combine.Sample.FixedSizePerKey, 5)
   248      individual_test_per_key_dd(combine.Sample.FixedSizeGlobally, 5)
   249  
   250    def test_combine_globally_display_data(self):
   251      transform = beam.CombineGlobally(combine.Smallest(5))
   252      dd = DisplayData.create_from(transform)
   253      expected_items = [
   254          DisplayDataItemMatcher('combine_fn', combine.Smallest),
   255          DisplayDataItemMatcher('n', 5),
   256          DisplayDataItemMatcher('compare', 'gt')
   257      ]
   258      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   259  
   260    def test_basic_combiners_display_data(self):
   261      transform = beam.CombineGlobally(
   262          combine.TupleCombineFn(max, combine.MeanCombineFn(), sum))
   263      dd = DisplayData.create_from(transform)
   264      expected_items = [
   265          DisplayDataItemMatcher('combine_fn', combine.TupleCombineFn),
   266          DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']"),
   267          DisplayDataItemMatcher('merge_accumulators_batch_size', 333),
   268      ]
   269      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   270  
   271    def test_top_shorthands(self):
   272      with TestPipeline() as pipeline:
   273  
   274        pcoll = pipeline | 'start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
   275        result_top = pcoll | 'top' >> beam.CombineGlobally(combine.Largest(5))
   276        result_bot = pcoll | 'bot' >> beam.CombineGlobally(combine.Smallest(4))
   277        assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='assert:top')
   278        assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='assert:bot')
   279  
   280        pcoll = pipeline | 'start-perkey' >> Create(
   281            [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
   282        result_ktop = pcoll | 'top-perkey' >> beam.CombinePerKey(
   283            combine.Largest(5))
   284        result_kbot = pcoll | 'bot-perkey' >> beam.CombinePerKey(
   285            combine.Smallest(4))
   286        assert_that(result_ktop, equal_to([('a', [9, 6, 6, 5, 3])]), label='ktop')
   287        assert_that(result_kbot, equal_to([('a', [0, 1, 1, 1])]), label='kbot')
   288  
   289    def test_top_no_compact(self):
   290      class TopCombineFnNoCompact(combine.TopCombineFn):
   291        def compact(self, accumulator):
   292          return accumulator
   293  
   294      with TestPipeline() as pipeline:
   295        pcoll = pipeline | 'Start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
   296        result_top = pcoll | 'Top' >> beam.CombineGlobally(
   297            TopCombineFnNoCompact(5, key=lambda x: x))
   298        result_bot = pcoll | 'Bot' >> beam.CombineGlobally(
   299            TopCombineFnNoCompact(4, reverse=True))
   300        assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='Assert:Top')
   301        assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='Assert:Bot')
   302  
   303        pcoll = pipeline | 'Start-Perkey' >> Create(
   304            [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
   305        result_ktop = pcoll | 'Top-PerKey' >> beam.CombinePerKey(
   306            TopCombineFnNoCompact(5, key=lambda x: x))
   307        result_kbot = pcoll | 'Bot-PerKey' >> beam.CombinePerKey(
   308            TopCombineFnNoCompact(4, reverse=True))
   309        assert_that(result_ktop, equal_to([('a', [9, 6, 6, 5, 3])]), label='KTop')
   310        assert_that(result_kbot, equal_to([('a', [0, 1, 1, 1])]), label='KBot')
   311  
   312    def test_global_sample(self):
   313      def is_good_sample(actual):
   314        assert len(actual) == 1
   315        assert sorted(actual[0]) in [[1, 1, 2], [1, 2, 2]], actual
   316  
   317      with TestPipeline() as pipeline:
   318        timestamp = 0
   319        pcoll = pipeline | 'start' >> Create([1, 1, 2, 2])
   320  
   321        # Now for global combines without default
   322        timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
   323        windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
   324  
   325        for ix in range(9):
   326          assert_that(
   327              pcoll | 'sample-%d' % ix >> combine.Sample.FixedSizeGlobally(3),
   328              is_good_sample,
   329              label='check-%d' % ix)
   330          result_windowed = (
   331              windowed
   332              | 'sample-wo-defaults-%d' % ix >>
   333              combine.Sample.FixedSizeGlobally(3).without_defaults())
   334          assert_that(
   335              result_windowed, is_good_sample, label='check-wo-defaults-%d' % ix)
   336  
   337    def test_per_key_sample(self):
   338      with TestPipeline() as pipeline:
   339        pcoll = pipeline | 'start-perkey' >> Create(
   340            sum(([(i, 1), (i, 1), (i, 2), (i, 2)] for i in range(9)), []))
   341        result = pcoll | 'sample' >> combine.Sample.FixedSizePerKey(3)
   342  
   343        def matcher():
   344          def match(actual):
   345            for _, samples in actual:
   346              equal_to([3])([len(samples)])
   347              num_ones = sum(1 for x in samples if x == 1)
   348              num_twos = sum(1 for x in samples if x == 2)
   349              equal_to([1, 2])([num_ones, num_twos])
   350  
   351          return match
   352  
   353        assert_that(result, matcher())
   354  
   355    def test_tuple_combine_fn(self):
   356      with TestPipeline() as p:
   357        result = (
   358            p
   359            | Create([('a', 100, 0.0), ('b', 10, -1), ('c', 1, 100)])
   360            | beam.CombineGlobally(
   361                combine.TupleCombineFn(max, combine.MeanCombineFn(),
   362                                       sum)).without_defaults())
   363        assert_that(result, equal_to([('c', 111.0 / 3, 99.0)]))
   364  
   365    def test_tuple_combine_fn_without_defaults(self):
   366      with TestPipeline() as p:
   367        result = (
   368            p
   369            | Create([1, 1, 2, 3])
   370            | beam.CombineGlobally(
   371                combine.TupleCombineFn(
   372                    min, combine.MeanCombineFn(),
   373                    max).with_common_input()).without_defaults())
   374        assert_that(result, equal_to([(1, 7.0 / 4, 3)]))
   375  
   376    def test_empty_tuple_combine_fn(self):
   377      with TestPipeline() as p:
   378        result = (
   379            p
   380            | Create([(), (), ()])
   381            | beam.CombineGlobally(combine.TupleCombineFn()))
   382        assert_that(result, equal_to([()]))
   383  
   384    def test_tuple_combine_fn_batched_merge(self):
   385      num_combine_fns = 10
   386      max_num_accumulators_in_memory = 30
   387      # Maximum number of accumulator tuples in memory - 1 for the merge result.
   388      merge_accumulators_batch_size = (
   389          max_num_accumulators_in_memory // num_combine_fns - 1)
   390      num_accumulator_tuples_to_merge = 20
   391  
   392      class CountedAccumulator:
   393        count = 0
   394        oom = False
   395  
   396        def __init__(self):
   397          if CountedAccumulator.count > max_num_accumulators_in_memory:
   398            CountedAccumulator.oom = True
   399          else:
   400            CountedAccumulator.count += 1
   401  
   402      class CountedAccumulatorCombineFn(beam.CombineFn):
   403        def create_accumulator(self):
   404          return CountedAccumulator()
   405  
   406        def merge_accumulators(self, accumulators):
   407          CountedAccumulator.count += 1
   408          for _ in accumulators:
   409            CountedAccumulator.count -= 1
   410  
   411      combine_fn = combine.TupleCombineFn(
   412          *[CountedAccumulatorCombineFn() for _ in range(num_combine_fns)],
   413          merge_accumulators_batch_size=merge_accumulators_batch_size)
   414      combine_fn.merge_accumulators(
   415          combine_fn.create_accumulator()
   416          for _ in range(num_accumulator_tuples_to_merge))
   417      assert not CountedAccumulator.oom
   418  
   419    def test_to_list_and_to_dict1(self):
   420      with TestPipeline() as pipeline:
   421        the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
   422        timestamp = 0
   423        pcoll = pipeline | 'start' >> Create(the_list)
   424        result = pcoll | 'to list' >> combine.ToList()
   425  
   426        # Now for global combines without default
   427        timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
   428        windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
   429        result_windowed = (
   430            windowed
   431            | 'to list wo defaults' >> combine.ToList().without_defaults())
   432  
   433        def matcher(expected):
   434          def match(actual):
   435            equal_to(expected[0])(actual[0])
   436  
   437          return match
   438  
   439        assert_that(result, matcher([the_list]))
   440        assert_that(
   441            result_windowed, matcher([the_list]), label='to-list-wo-defaults')
   442  
   443    def test_to_list_and_to_dict2(self):
   444      with TestPipeline() as pipeline:
   445        pairs = [(1, 2), (3, 4), (5, 6)]
   446        timestamp = 0
   447        pcoll = pipeline | 'start-pairs' >> Create(pairs)
   448        result = pcoll | 'to dict' >> combine.ToDict()
   449  
   450        # Now for global combines without default
   451        timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
   452        windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
   453        result_windowed = (
   454            windowed
   455            | 'to dict wo defaults' >> combine.ToDict().without_defaults())
   456  
   457        def matcher():
   458          def match(actual):
   459            equal_to([1])([len(actual)])
   460            equal_to(pairs)(actual[0].items())
   461  
   462          return match
   463  
   464        assert_that(result, matcher())
   465        assert_that(result_windowed, matcher(), label='to-dict-wo-defaults')
   466  
   467    def test_to_set(self):
   468      pipeline = TestPipeline()
   469      the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
   470      timestamp = 0
   471      pcoll = pipeline | 'start' >> Create(the_list)
   472      result = pcoll | 'to set' >> combine.ToSet()
   473  
   474      # Now for global combines without default
   475      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
   476      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
   477      result_windowed = (
   478          windowed
   479          | 'to set wo defaults' >> combine.ToSet().without_defaults())
   480  
   481      def matcher(expected):
   482        def match(actual):
   483          equal_to(expected[0])(actual[0])
   484  
   485        return match
   486  
   487      assert_that(result, matcher(set(the_list)))
   488      assert_that(
   489          result_windowed, matcher(set(the_list)), label='to-set-wo-defaults')
   490  
   491    def test_combine_globally_with_default(self):
   492      with TestPipeline() as p:
   493        assert_that(p | Create([]) | CombineGlobally(sum), equal_to([0]))
   494  
   495    def test_combine_globally_without_default(self):
   496      with TestPipeline() as p:
   497        result = p | Create([]) | CombineGlobally(sum).without_defaults()
   498        assert_that(result, equal_to([]))
   499  
   500    def test_combine_globally_with_default_side_input(self):
   501      class SideInputCombine(PTransform):
   502        def expand(self, pcoll):
   503          side = pcoll | CombineGlobally(sum).as_singleton_view()
   504          main = pcoll.pipeline | Create([None])
   505          return main | Map(lambda _, s: s, side)
   506  
   507      with TestPipeline() as p:
   508        result1 = p | 'i1' >> Create([]) | 'c1' >> SideInputCombine()
   509        result2 = p | 'i2' >> Create([1, 2, 3, 4]) | 'c2' >> SideInputCombine()
   510        assert_that(result1, equal_to([0]), label='r1')
   511        assert_that(result2, equal_to([10]), label='r2')
   512  
   513    def test_hot_key_fanout(self):
   514      with TestPipeline() as p:
   515        result = (
   516            p
   517            | beam.Create(itertools.product(['hot', 'cold'], range(10)))
   518            | beam.CombinePerKey(combine.MeanCombineFn()).with_hot_key_fanout(
   519                lambda key: (key == 'hot') * 5))
   520        assert_that(result, equal_to([('hot', 4.5), ('cold', 4.5)]))
   521  
   522    def test_hot_key_fanout_sharded(self):
   523      # Lots of elements with the same key with varying/no fanout.
   524      with TestPipeline() as p:
   525        elements = [(None, e) for e in range(1000)]
   526        random.shuffle(elements)
   527        shards = [
   528            p | "Shard%s" % shard >> beam.Create(elements[shard::20])
   529            for shard in range(20)
   530        ]
   531        result = (
   532            shards
   533            | beam.Flatten()
   534            | beam.CombinePerKey(combine.MeanCombineFn()).with_hot_key_fanout(
   535                lambda key: random.randrange(0, 5)))
   536        assert_that(result, equal_to([(None, 499.5)]))
   537  
   538    def test_global_fanout(self):
   539      with TestPipeline() as p:
   540        result = (
   541            p
   542            | beam.Create(range(100))
   543            | beam.CombineGlobally(combine.MeanCombineFn()).with_fanout(11))
   544        assert_that(result, equal_to([49.5]))
   545  
   546    def test_combining_with_accumulation_mode_and_fanout(self):
   547      # PCollection will contain elements from 1 to 5.
   548      elements = [i for i in range(1, 6)]
   549  
   550      ts = TestStream().advance_watermark_to(0)
   551      for i in elements:
   552        ts.add_elements([i])
   553      ts.advance_watermark_to_infinity()
   554  
   555      options = PipelineOptions()
   556      options.view_as(StandardOptions).streaming = True
   557      with TestPipeline(options=options) as p:
   558        result = (
   559            p
   560            | ts
   561            | beam.WindowInto(
   562                GlobalWindows(),
   563                accumulation_mode=trigger.AccumulationMode.ACCUMULATING,
   564                trigger=AfterWatermark(early=AfterAll(AfterCount(1))))
   565            | beam.CombineGlobally(sum).without_defaults().with_fanout(2))
   566  
   567        def has_expected_values(actual):
   568          from hamcrest.core import assert_that as hamcrest_assert
   569          from hamcrest.library.collection import contains
   570          from hamcrest.library.collection import only_contains
   571          ordered = sorted(actual)
   572          # Early firings.
   573          hamcrest_assert(ordered[:4], contains(1, 3, 6, 10))
   574          # Different runners have different number of 15s, but there should
   575          # be at least one 15.
   576          hamcrest_assert(ordered[4:], only_contains(15))
   577  
   578        assert_that(result, has_expected_values)
   579  
   580    def test_combining_with_sliding_windows_and_fanout_raises_error(self):
   581      options = PipelineOptions()
   582      options.view_as(StandardOptions).streaming = True
   583      with self.assertRaises(ValueError):
   584        with TestPipeline(options=options) as p:
   585          _ = (
   586              p
   587              | beam.Create([
   588                  window.TimestampedValue(0, Timestamp(seconds=1666707510)),
   589                  window.TimestampedValue(1, Timestamp(seconds=1666707511)),
   590                  window.TimestampedValue(2, Timestamp(seconds=1666707512)),
   591                  window.TimestampedValue(3, Timestamp(seconds=1666707513)),
   592                  window.TimestampedValue(5, Timestamp(seconds=1666707515)),
   593                  window.TimestampedValue(6, Timestamp(seconds=1666707516)),
   594                  window.TimestampedValue(7, Timestamp(seconds=1666707517)),
   595                  window.TimestampedValue(8, Timestamp(seconds=1666707518))
   596              ])
   597              | beam.WindowInto(window.SlidingWindows(10, 5))
   598              | beam.CombineGlobally(beam.combiners.ToListCombineFn()).
   599              without_defaults().with_fanout(7))
   600  
   601    def test_MeanCombineFn_combine(self):
   602      with TestPipeline() as p:
   603        input = (
   604            p
   605            | beam.Create([('a', 1), ('a', 1), ('a', 4), ('b', 1), ('b', 13)]))
   606        # The mean of all values regardless of key.
   607        global_mean = (
   608            input
   609            | beam.Values()
   610            | beam.CombineGlobally(combine.MeanCombineFn()))
   611  
   612        # The (key, mean) pairs for all keys.
   613        mean_per_key = (input | beam.CombinePerKey(combine.MeanCombineFn()))
   614  
   615        expected_mean_per_key = [('a', 2), ('b', 7)]
   616        assert_that(global_mean, equal_to([4]), label='global mean')
   617        assert_that(
   618            mean_per_key, equal_to(expected_mean_per_key), label='mean per key')
   619  
   620    def test_MeanCombineFn_combine_empty(self):
   621      # For each element in a PCollection, if it is float('NaN'), then emits
   622      # a string 'NaN', otherwise emits str(element).
   623  
   624      with TestPipeline() as p:
   625        input = (p | beam.Create([]))
   626  
   627        # Compute the mean of all values in the PCollection,
   628        # then format the mean. Since the Pcollection is empty,
   629        # the mean is float('NaN'), and is formatted to be a string 'NaN'.
   630        global_mean = (
   631            input
   632            | beam.Values()
   633            | beam.CombineGlobally(combine.MeanCombineFn())
   634            | beam.Map(str))
   635  
   636        mean_per_key = (input | beam.CombinePerKey(combine.MeanCombineFn()))
   637  
   638        # We can't compare one float('NaN') with another float('NaN'),
   639        # but we can compare one 'nan' string with another string.
   640        assert_that(global_mean, equal_to(['nan']), label='global mean')
   641        assert_that(mean_per_key, equal_to([]), label='mean per key')
   642  
   643    def test_sessions_combine(self):
   644      with TestPipeline() as p:
   645        input = (
   646            p
   647            | beam.Create([('c', 1), ('c', 9), ('c', 12), ('d', 2), ('d', 4)])
   648            | beam.MapTuple(lambda k, v: window.TimestampedValue((k, v), v))
   649            | beam.WindowInto(window.Sessions(4)))
   650  
   651        global_sum = (
   652            input
   653            | beam.Values()
   654            | beam.CombineGlobally(sum).without_defaults())
   655        sum_per_key = input | beam.CombinePerKey(sum)
   656  
   657        # The first window has 3 elements: ('c', 1), ('d', 2), ('d', 4).
   658        # The second window has 2 elements: ('c', 9), ('c', 12).
   659        assert_that(global_sum, equal_to([7, 21]), label='global sum')
   660        assert_that(
   661            sum_per_key,
   662            equal_to([('c', 1), ('c', 21), ('d', 6)]),
   663            label='sum per key')
   664  
   665    def test_fixed_windows_combine(self):
   666      with TestPipeline() as p:
   667        input = (
   668            p
   669            | beam.Create([('c', 1), ('c', 2), ('c', 10), ('d', 5), ('d', 8),
   670                           ('d', 9)])
   671            | beam.MapTuple(lambda k, v: window.TimestampedValue((k, v), v))
   672            | beam.WindowInto(window.FixedWindows(4)))
   673  
   674        global_sum = (
   675            input
   676            | beam.Values()
   677            | beam.CombineGlobally(sum).without_defaults())
   678        sum_per_key = input | beam.CombinePerKey(sum)
   679  
   680        # The first window has 2 elements: ('c', 1), ('c', 2).
   681        # The second window has 1 elements: ('d', 5).
   682        # The third window has 3 elements: ('c', 10), ('d', 8), ('d', 9).
   683        assert_that(global_sum, equal_to([3, 5, 27]), label='global sum')
   684        assert_that(
   685            sum_per_key,
   686            equal_to([('c', 3), ('c', 10), ('d', 5), ('d', 17)]),
   687            label='sum per key')
   688  
   689    # Test that three different kinds of metrics work with a customized
   690    # SortedConcatWithCounters CombineFn.
   691    def test_custormized_counters_in_combine_fn(self):
   692      p = TestPipeline()
   693      input = (
   694          p
   695          | beam.Create([('key1', 'a'), ('key1', 'ab'), ('key1', 'abc'),
   696                         ('key2', 'uvxy'), ('key2', 'uvxyz')]))
   697  
   698      # The result of concatenating all values regardless of key.
   699      global_concat = (
   700          input
   701          | beam.Values()
   702          | beam.CombineGlobally(SortedConcatWithCounters()))
   703  
   704      # The (key, concatenated_string) pairs for all keys.
   705      concat_per_key = (input | beam.CombinePerKey(SortedConcatWithCounters()))
   706  
   707      # Verify the concatenated strings are correct.
   708      expected_concat_per_key = [('key1', 'aaabbc'), ('key2', 'uuvvxxyyz')]
   709      assert_that(
   710          global_concat, equal_to(['aaabbcuuvvxxyyz']), label='global concat')
   711      assert_that(
   712          concat_per_key,
   713          equal_to(expected_concat_per_key),
   714          label='concat per key')
   715  
   716      result = p.run()
   717      result.wait_until_finish()
   718  
   719      # Verify the values of metrics are correct.
   720      word_counter_filter = MetricsFilter().with_name('word_counter')
   721      query_result = result.metrics().query(word_counter_filter)
   722      if query_result['counters']:
   723        word_counter = query_result['counters'][0]
   724        self.assertEqual(word_counter.result, 5)
   725  
   726      word_lengths_filter = MetricsFilter().with_name('word_lengths')
   727      query_result = result.metrics().query(word_lengths_filter)
   728      if query_result['counters']:
   729        word_lengths = query_result['counters'][0]
   730        self.assertEqual(word_lengths.result, 15)
   731  
   732      word_len_dist_filter = MetricsFilter().with_name('word_len_dist')
   733      query_result = result.metrics().query(word_len_dist_filter)
   734      if query_result['distributions']:
   735        word_len_dist = query_result['distributions'][0]
   736        self.assertEqual(word_len_dist.result.mean, 3)
   737  
   738      last_word_len_filter = MetricsFilter().with_name('last_word_len')
   739      query_result = result.metrics().query(last_word_len_filter)
   740      if query_result['gauges']:
   741        last_word_len = query_result['gauges'][0]
   742        self.assertIn(last_word_len.result.value, [1, 2, 3, 4, 5])
   743  
   744    # Test that three different kinds of metrics work with the customized
   745    # SortedConcatWithCounters CombineFn when the PCollection is empty.
   746    def test_custormized_counters_in_combine_fn_empty(self):
   747      p = TestPipeline()
   748      input = p | beam.Create([])
   749  
   750      # The result of concatenating all values regardless of key.
   751      global_concat = (
   752          input
   753          | beam.Values()
   754          | beam.CombineGlobally(SortedConcatWithCounters()))
   755  
   756      # The (key, concatenated_string) pairs for all keys.
   757      concat_per_key = (input | beam.CombinePerKey(SortedConcatWithCounters()))
   758  
   759      # Verify the concatenated strings are correct.
   760      assert_that(global_concat, equal_to(['']), label='global concat')
   761      assert_that(concat_per_key, equal_to([]), label='concat per key')
   762  
   763      result = p.run()
   764      result.wait_until_finish()
   765  
   766      # Verify the values of metrics are correct.
   767      word_counter_filter = MetricsFilter().with_name('word_counter')
   768      query_result = result.metrics().query(word_counter_filter)
   769      if query_result['counters']:
   770        word_counter = query_result['counters'][0]
   771        self.assertEqual(word_counter.result, 0)
   772  
   773      word_lengths_filter = MetricsFilter().with_name('word_lengths')
   774      query_result = result.metrics().query(word_lengths_filter)
   775      if query_result['counters']:
   776        word_lengths = query_result['counters'][0]
   777        self.assertEqual(word_lengths.result, 0)
   778  
   779      word_len_dist_filter = MetricsFilter().with_name('word_len_dist')
   780      query_result = result.metrics().query(word_len_dist_filter)
   781      if query_result['distributions']:
   782        word_len_dist = query_result['distributions'][0]
   783        self.assertEqual(word_len_dist.result.count, 0)
   784  
   785      last_word_len_filter = MetricsFilter().with_name('last_word_len')
   786      query_result = result.metrics().query(last_word_len_filter)
   787  
   788      # No element has ever been recorded.
   789      self.assertFalse(query_result['gauges'])
   790  
   791  
   792  class LatestTest(unittest.TestCase):
   793    def test_globally(self):
   794      l = [
   795          window.TimestampedValue(3, 100),
   796          window.TimestampedValue(1, 200),
   797          window.TimestampedValue(2, 300)
   798      ]
   799      with TestPipeline() as p:
   800        # Map(lambda x: x) PTransform is added after Create here, because when
   801        # a PCollection of TimestampedValues is created with Create PTransform,
   802        # the timestamps are not assigned to it. Adding a Map forces the
   803        # PCollection to go through a DoFn so that the PCollection consists of
   804        # the elements with timestamps assigned to them instead of a PCollection
   805        # of TimestampedValue(element, timestamp).
   806        pcoll = p | Create(l) | Map(lambda x: x)
   807        latest = pcoll | combine.Latest.Globally()
   808        assert_that(latest, equal_to([2]))
   809  
   810        # Now for global combines without default
   811        windowed = pcoll | 'window' >> WindowInto(FixedWindows(180))
   812        result_windowed = (
   813            windowed
   814            |
   815            'latest wo defaults' >> combine.Latest.Globally().without_defaults())
   816  
   817        assert_that(result_windowed, equal_to([3, 2]), label='latest-wo-defaults')
   818  
   819    def test_globally_empty(self):
   820      l = []
   821      with TestPipeline() as p:
   822        pc = p | Create(l) | Map(lambda x: x)
   823        latest = pc | combine.Latest.Globally()
   824        assert_that(latest, equal_to([None]))
   825  
   826    def test_per_key(self):
   827      l = [
   828          window.TimestampedValue(('a', 1), 300),
   829          window.TimestampedValue(('b', 3), 100),
   830          window.TimestampedValue(('a', 2), 200)
   831      ]
   832      with TestPipeline() as p:
   833        pc = p | Create(l) | Map(lambda x: x)
   834        latest = pc | combine.Latest.PerKey()
   835        assert_that(latest, equal_to([('a', 1), ('b', 3)]))
   836  
   837    def test_per_key_empty(self):
   838      l = []
   839      with TestPipeline() as p:
   840        pc = p | Create(l) | Map(lambda x: x)
   841        latest = pc | combine.Latest.PerKey()
   842        assert_that(latest, equal_to([]))
   843  
   844  
   845  class LatestCombineFnTest(unittest.TestCase):
   846    def setUp(self):
   847      self.fn = combine.LatestCombineFn()
   848  
   849    def test_create_accumulator(self):
   850      accumulator = self.fn.create_accumulator()
   851      self.assertEqual(accumulator, (None, window.MIN_TIMESTAMP))
   852  
   853    def test_add_input(self):
   854      accumulator = self.fn.create_accumulator()
   855      element = (1, 100)
   856      new_accumulator = self.fn.add_input(accumulator, element)
   857      self.assertEqual(new_accumulator, (1, 100))
   858  
   859    def test_merge_accumulators(self):
   860      accumulators = [(2, 400), (5, 100), (9, 200)]
   861      merged_accumulator = self.fn.merge_accumulators(accumulators)
   862      self.assertEqual(merged_accumulator, (2, 400))
   863  
   864    def test_extract_output(self):
   865      accumulator = (1, 100)
   866      output = self.fn.extract_output(accumulator)
   867      self.assertEqual(output, 1)
   868  
   869    def test_with_input_types_decorator_violation(self):
   870      l_int = [1, 2, 3]
   871      l_dict = [{'a': 3}, {'g': 5}, {'r': 8}]
   872      l_3_tuple = [(12, 31, 41), (12, 34, 34), (84, 92, 74)]
   873  
   874      with self.assertRaises(TypeCheckError):
   875        with TestPipeline() as p:
   876          pc = p | Create(l_int)
   877          _ = pc | beam.CombineGlobally(self.fn)
   878  
   879      with self.assertRaises(TypeCheckError):
   880        with TestPipeline() as p:
   881          pc = p | Create(l_dict)
   882          _ = pc | beam.CombineGlobally(self.fn)
   883  
   884      with self.assertRaises(TypeCheckError):
   885        with TestPipeline() as p:
   886          pc = p | Create(l_3_tuple)
   887          _ = pc | beam.CombineGlobally(self.fn)
   888  
   889  
   890  @pytest.mark.it_validatesrunner
   891  class CombineValuesTest(unittest.TestCase):
   892    def test_gbk_immediately_followed_by_combine(self):
   893      def merge(vals):
   894        return "".join(vals)
   895  
   896      with TestPipeline() as p:
   897        result = (
   898            p \
   899            | Create([("key1", "foo"), ("key2", "bar"), ("key1", "foo")],
   900                      reshuffle=False) \
   901            | beam.GroupByKey() \
   902            | beam.CombineValues(merge) \
   903            | beam.MapTuple(lambda k, v: '{}: {}'.format(k, v)))
   904  
   905        assert_that(result, equal_to(['key1: foofoo', 'key2: bar']))
   906  
   907  
   908  #
   909  # Test cases for streaming.
   910  #
   911  @pytest.mark.it_validatesrunner
   912  class TimestampCombinerTest(unittest.TestCase):
   913    def test_combiner_earliest(self):
   914      """Test TimestampCombiner with EARLIEST."""
   915      options = PipelineOptions(streaming=True)
   916      with TestPipeline(options=options) as p:
   917        result = (
   918            p
   919            | TestStream().add_elements([window.TimestampedValue(
   920                ('k', 100), 2)]).add_elements(
   921                    [window.TimestampedValue(
   922                        ('k', 400), 7)]).advance_watermark_to_infinity()
   923            | beam.WindowInto(
   924                window.FixedWindows(10),
   925                timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST)
   926            | beam.CombinePerKey(sum))
   927  
   928        records = (
   929            result
   930            | beam.Map(lambda e, ts=beam.DoFn.TimestampParam: (e, ts)))
   931  
   932        # All the KV pairs are applied GBK using EARLIEST timestamp for the same
   933        # key.
   934        expected_window_to_elements = {
   935            window.IntervalWindow(0, 10): [
   936                (('k', 500), Timestamp(2)),
   937            ],
   938        }
   939  
   940        assert_that(
   941            records,
   942            equal_to_per_window(expected_window_to_elements),
   943            use_global_window=False,
   944            label='assert per window')
   945  
   946    def test_combiner_latest(self):
   947      """Test TimestampCombiner with LATEST."""
   948      options = PipelineOptions(streaming=True)
   949      with TestPipeline(options=options) as p:
   950        result = (
   951            p
   952            | TestStream().add_elements([window.TimestampedValue(
   953                ('k', 100), 2)]).add_elements(
   954                    [window.TimestampedValue(
   955                        ('k', 400), 7)]).advance_watermark_to_infinity()
   956            | beam.WindowInto(
   957                window.FixedWindows(10),
   958                timestamp_combiner=TimestampCombiner.OUTPUT_AT_LATEST)
   959            | beam.CombinePerKey(sum))
   960  
   961        records = (
   962            result
   963            | beam.Map(lambda e, ts=beam.DoFn.TimestampParam: (e, ts)))
   964  
   965        # All the KV pairs are applied GBK using LATEST timestamp for
   966        # the same key.
   967        expected_window_to_elements = {
   968            window.IntervalWindow(0, 10): [
   969                (('k', 500), Timestamp(7)),
   970            ],
   971        }
   972  
   973        assert_that(
   974            records,
   975            equal_to_per_window(expected_window_to_elements),
   976            use_global_window=False,
   977            label='assert per window')
   978  
   979  
# Run the tests in this module through unittest when it is executed as a
# script.
if __name__ == '__main__':
  unittest.main()