github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/combiners_test.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """Unit tests for our libraries of combine PTransforms.""" 19 # pytype: skip-file 20 21 import itertools 22 import random 23 import unittest 24 25 import hamcrest as hc 26 import pytest 27 28 import apache_beam as beam 29 import apache_beam.transforms.combiners as combine 30 from apache_beam.metrics import Metrics 31 from apache_beam.metrics import MetricsFilter 32 from apache_beam.options.pipeline_options import PipelineOptions 33 from apache_beam.options.pipeline_options import StandardOptions 34 from apache_beam.testing.test_pipeline import TestPipeline 35 from apache_beam.testing.test_stream import TestStream 36 from apache_beam.testing.util import assert_that 37 from apache_beam.testing.util import equal_to 38 from apache_beam.testing.util import equal_to_per_window 39 from apache_beam.transforms import WindowInto 40 from apache_beam.transforms import trigger 41 from apache_beam.transforms import window 42 from apache_beam.transforms.core import CombineGlobally 43 from apache_beam.transforms.core import Create 44 from apache_beam.transforms.core import Map 45 
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.trigger import AfterAll
from apache_beam.transforms.trigger import AfterCount
from apache_beam.transforms.trigger import AfterWatermark
from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import TypeCheckError
from apache_beam.utils.timestamp import Timestamp


class SortedConcatWithCounters(beam.CombineFn):
  """CombineFn that concatenates words while updating user metrics.

  Every input word updates four metrics:
    * ``word_counter`` (counter): incremented by 1 per word.
    * ``word_lengths`` (counter): incremented by the word's length.
    * ``word_len_dist`` (distribution): updated with the word's length.
    * ``last_word_len`` (gauge): set to the word's length.

  The output is the characters of all input words, sorted and joined into a
  single string.
  """
  def __init__(self):
    beam.CombineFn.__init__(self)
    self.word_counter = Metrics.counter(self.__class__, 'word_counter')
    self.word_lengths_counter = Metrics.counter(self.__class__, 'word_lengths')
    self.word_lengths_dist = Metrics.distribution(
        self.__class__, 'word_len_dist')
    self.last_word_len = Metrics.gauge(self.__class__, 'last_word_len')

  def create_accumulator(self):
    return ''

  def add_input(self, acc, element):
    self.word_counter.inc(1)
    self.word_lengths_counter.inc(len(element))
    self.word_lengths_dist.update(len(element))
    self.last_word_len.set(len(element))

    return acc + element

  def merge_accumulators(self, accs):
    return ''.join(accs)

  def extract_output(self, acc):
    # sorted() on a string yields a list of characters, so it has to be
    # joined back into a string.
    return ''.join(sorted(acc))


class CombineTest(unittest.TestCase):
  """Tests for the combiners in ``apache_beam.transforms.combiners``."""
  def test_builtin_combines(self):
    """Mean and Count, globally (with and without defaults) and per key."""
    with TestPipeline() as pipeline:

      vals = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
      mean = sum(vals) / float(len(vals))
      size = len(vals)
      timestamp = 0

      # First for global combines.
      pcoll = pipeline | 'start' >> Create(vals)
      result_mean = pcoll | 'mean' >> combine.Mean.Globally()
      result_count = pcoll | 'count' >> combine.Count.Globally()
      assert_that(result_mean, equal_to([mean]), label='assert:mean')
      assert_that(result_count, equal_to([size]), label='assert:size')

      # Now for global combines without default
      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
      result_windowed_mean = (
          windowed
          | 'mean-wo-defaults' >> combine.Mean.Globally().without_defaults())
      assert_that(
          result_windowed_mean,
          equal_to([mean]),
          label='assert:mean-wo-defaults')
      result_windowed_count = (
          windowed
          | 'count-wo-defaults' >> combine.Count.Globally().without_defaults())
      assert_that(
          result_windowed_count,
          equal_to([size]),
          label='assert:count-wo-defaults')

      # Again for per-key combines.
      pcoll = pipeline | 'start-perkey' >> Create([('a', x) for x in vals])
      result_key_mean = pcoll | 'mean-perkey' >> combine.Mean.PerKey()
      result_key_count = pcoll | 'count-perkey' >> combine.Count.PerKey()
      assert_that(result_key_mean, equal_to([('a', mean)]), label='key:mean')
      assert_that(result_key_count, equal_to([('a', size)]), label='key:size')

  def test_top(self):
    """Top.Largest/Smallest, globally (with and without defaults) and per key."""
    with TestPipeline() as pipeline:
      timestamp = 0

      # First for global combines.
      pcoll = pipeline | 'start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
      result_top = pcoll | 'top' >> combine.Top.Largest(5)
      result_bot = pcoll | 'bot' >> combine.Top.Smallest(4)
      assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='assert:top')
      assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='assert:bot')

      # Now for global combines without default
      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
      result_windowed_top = windowed | 'top-wo-defaults' >> combine.Top.Largest(
          5, has_defaults=False)
      result_windowed_bot = (
          windowed
          | 'bot-wo-defaults' >> combine.Top.Smallest(4, has_defaults=False))
      assert_that(
          result_windowed_top,
          equal_to([[9, 6, 6, 5, 3]]),
          label='assert:top-wo-defaults')
      assert_that(
          result_windowed_bot,
          equal_to([[0, 1, 1, 1]]),
          label='assert:bot-wo-defaults')

      # Again for per-key combines.
      pcoll = pipeline | 'start-perkey' >> Create(
          [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
      result_key_top = pcoll | 'top-perkey' >> combine.Top.LargestPerKey(5)
      result_key_bot = pcoll | 'bot-perkey' >> combine.Top.SmallestPerKey(4)
      assert_that(
          result_key_top, equal_to([('a', [9, 6, 6, 5, 3])]), label='key:top')
      assert_that(
          result_key_bot, equal_to([('a', [0, 1, 1, 1])]), label='key:bot')

  def test_empty_global_top(self):
    """A global Top over no elements emits a single empty list."""
    with TestPipeline() as p:
      assert_that(p | beam.Create([]) | combine.Top.Largest(10), equal_to([[]]))

  def test_sharded_top(self):
    """Top over several flattened, shuffled shards finds the global top."""
    elements = list(range(100))
    random.shuffle(elements)

    with TestPipeline() as pipeline:
      shards = [
          pipeline | 'Shard%s' % shard >> beam.Create(elements[shard::7])
          for shard in range(7)
      ]
      assert_that(
          shards | beam.Flatten() | combine.Top.Largest(10),
          equal_to([[99, 98, 97, 96, 95, 94, 93, 92, 91, 90]]))

  def test_top_key(self):
    """Top variants honour the ``key`` and ``reverse`` arguments."""
    self.assertEqual(['aa', 'bbb', 'c', 'dddd'] | combine.Top.Of(3, key=len),
                     [['dddd', 'bbb', 'aa']])
    self.assertEqual(['aa', 'bbb', 'c', 'dddd']
                     | combine.Top.Of(3, key=len, reverse=True),
                     [['c', 'aa', 'bbb']])

    self.assertEqual(['xc', 'zb', 'yd', 'wa']
                     | combine.Top.Largest(3, key=lambda x: x[-1]),
                     [['yd', 'xc', 'zb']])
    self.assertEqual(['xc', 'zb', 'yd', 'wa']
                     | combine.Top.Smallest(3, key=lambda x: x[-1]),
                     [['wa', 'zb', 'xc']])

    self.assertEqual([('a', x) for x in [1, 2, 3, 4, 1, 1]]
                     | combine.Top.LargestPerKey(3, key=lambda x: -x),
                     [('a', [1, 1, 1])])
    self.assertEqual([('a', x) for x in [1, 2, 3, 4, 1, 1]]
                     | combine.Top.SmallestPerKey(3, key=lambda x: -x),
                     [('a', [4, 3, 2])])

  def test_sharded_top_combine_fn(self):
    """TopCombineFn gives correct results when accumulators are merged."""
    def test_combine_fn(combine_fn, shards, expected):
      # One accumulator per shard, then a single merge, as a runner would do.
      accumulators = [
          combine_fn.add_inputs(combine_fn.create_accumulator(), shard)
          for shard in shards
      ]
      final_accumulator = combine_fn.merge_accumulators(accumulators)
      self.assertEqual(combine_fn.extract_output(final_accumulator), expected)

    test_combine_fn(combine.TopCombineFn(3), [range(10), range(10)], [9, 9, 8])
    test_combine_fn(
        combine.TopCombineFn(5), [range(1000), range(100), range(1001)],
        [1000, 999, 999, 998, 998])

  def test_combine_per_key_top_display_data(self):
    """CombinePerKey over Top-style CombineFns exposes their display data."""
    def individual_test_per_key_dd(combineFn):
      transform = beam.CombinePerKey(combineFn)
      dd = DisplayData.create_from(transform)
      expected_items = [
          DisplayDataItemMatcher('combine_fn', combineFn.__class__),
          DisplayDataItemMatcher('n', combineFn._n),
          DisplayDataItemMatcher('compare', combineFn._compare.__name__)
      ]
      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

    individual_test_per_key_dd(combine.Largest(5))
    individual_test_per_key_dd(combine.Smallest(3))
    individual_test_per_key_dd(combine.TopCombineFn(8))

  def test_combine_sample_display_data(self):
    """Sample transforms expose their sample size ``n`` as display data."""
    def individual_test_sample_dd(sampleFn, n):
      transform = sampleFn(n)
      dd = DisplayData.create_from(transform)
      hc.assert_that(
          dd.items,
          hc.contains_inanyorder(DisplayDataItemMatcher('n', transform._n)))

    individual_test_sample_dd(combine.Sample.FixedSizePerKey, 5)
    individual_test_sample_dd(combine.Sample.FixedSizeGlobally, 5)

  def test_combine_globally_display_data(self):
    """CombineGlobally(Smallest(5)) exposes its CombineFn's display data."""
    transform = beam.CombineGlobally(combine.Smallest(5))
    dd = DisplayData.create_from(transform)
    expected_items = [
        DisplayDataItemMatcher('combine_fn', combine.Smallest),
        DisplayDataItemMatcher('n', 5),
        DisplayDataItemMatcher('compare', 'gt')
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_basic_combiners_display_data(self):
    """TupleCombineFn lists its component combiners and merge batch size."""
    transform = beam.CombineGlobally(
        combine.TupleCombineFn(max, combine.MeanCombineFn(), sum))
    dd = DisplayData.create_from(transform)
    expected_items = [
        DisplayDataItemMatcher('combine_fn', combine.TupleCombineFn),
        DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']"),
        DisplayDataItemMatcher('merge_accumulators_batch_size', 333),
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_top_shorthands(self):
    """The Largest/Smallest CombineFns work with CombineGlobally/PerKey."""
    with TestPipeline() as pipeline:

      pcoll = pipeline | 'start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
      result_top = pcoll | 'top' >> beam.CombineGlobally(combine.Largest(5))
      result_bot = pcoll | 'bot' >> beam.CombineGlobally(combine.Smallest(4))
      assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='assert:top')
      assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='assert:bot')

      pcoll = pipeline | 'start-perkey' >> Create(
          [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
      result_ktop = pcoll | 'top-perkey' >> beam.CombinePerKey(
          combine.Largest(5))
      result_kbot = pcoll | 'bot-perkey' >> beam.CombinePerKey(
          combine.Smallest(4))
      assert_that(result_ktop, equal_to([('a', [9, 6, 6, 5, 3])]), label='ktop')
      assert_that(result_kbot, equal_to([('a', [0, 1, 1, 1])]), label='kbot')

  def test_top_no_compact(self):
    """TopCombineFn stays correct when ``compact`` is a no-op."""
    class TopCombineFnNoCompact(combine.TopCombineFn):
      def compact(self, accumulator):
        return accumulator

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
      result_top = pcoll | 'Top' >> beam.CombineGlobally(
          TopCombineFnNoCompact(5, key=lambda x: x))
      result_bot = pcoll | 'Bot' >> beam.CombineGlobally(
          TopCombineFnNoCompact(4, reverse=True))
      assert_that(result_top, equal_to([[9, 6, 6, 5, 3]]), label='Assert:Top')
      assert_that(result_bot, equal_to([[0, 1, 1, 1]]), label='Assert:Bot')

      pcoll = pipeline | 'Start-Perkey' >> Create(
          [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
      result_ktop = pcoll | 'Top-PerKey' >> beam.CombinePerKey(
          TopCombineFnNoCompact(5, key=lambda x: x))
      result_kbot = pcoll | 'Bot-PerKey' >> beam.CombinePerKey(
          TopCombineFnNoCompact(4, reverse=True))
      assert_that(result_ktop, equal_to([('a', [9, 6, 6, 5, 3])]), label='KTop')
      assert_that(result_kbot, equal_to([('a', [0, 1, 1, 1])]), label='KBot')

  def test_global_sample(self):
    """Sample.FixedSizeGlobally emits one sample of 3 drawn from the input."""
    def is_good_sample(actual):
      assert len(actual) == 1
      assert sorted(actual[0]) in [[1, 1, 2], [1, 2, 2]], actual

    with TestPipeline() as pipeline:
      timestamp = 0
      pcoll = pipeline | 'start' >> Create([1, 1, 2, 2])

      # Now for global combines without default
      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))

      # Sampling is random, so repeat it to exercise several outcomes.
      for ix in range(9):
        assert_that(
            pcoll | 'sample-%d' % ix >> combine.Sample.FixedSizeGlobally(3),
            is_good_sample,
            label='check-%d' % ix)
        result_windowed = (
            windowed
            | 'sample-wo-defaults-%d' % ix >>
            combine.Sample.FixedSizeGlobally(3).without_defaults())
        assert_that(
            result_windowed, is_good_sample, label='check-wo-defaults-%d' % ix)

  def test_per_key_sample(self):
    """Sample.FixedSizePerKey emits a sample of 3 values for each key."""
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'start-perkey' >> Create(
          [(i, v) for i in range(9) for v in (1, 1, 2, 2)])
      result = pcoll | 'sample' >> combine.Sample.FixedSizePerKey(3)

      def matcher():
        def match(actual):
          for _, samples in actual:
            equal_to([3])([len(samples)])
            num_ones = sum(1 for x in samples if x == 1)
            num_twos = sum(1 for x in samples if x == 2)
            # equal_to is order-insensitive: accepts (1, 2) or (2, 1).
            equal_to([1, 2])([num_ones, num_twos])

        return match

      assert_that(result, matcher())

  def test_tuple_combine_fn(self):
    """TupleCombineFn applies one combiner per tuple position."""
    with TestPipeline() as p:
      result = (
          p
          | Create([('a', 100, 0.0), ('b', 10, -1), ('c', 1, 100)])
          | beam.CombineGlobally(
              combine.TupleCombineFn(max, combine.MeanCombineFn(),
                                     sum)).without_defaults())
      assert_that(result, equal_to([('c', 111.0 / 3, 99.0)]))

  def test_tuple_combine_fn_without_defaults(self):
    """with_common_input() feeds every element to every component combiner."""
    with TestPipeline() as p:
      result = (
          p
          | Create([1, 1, 2, 3])
          | beam.CombineGlobally(
              combine.TupleCombineFn(
                  min, combine.MeanCombineFn(),
                  max).with_common_input()).without_defaults())
      assert_that(result, equal_to([(1, 7.0 / 4, 3)]))

  def test_empty_tuple_combine_fn(self):
    """A TupleCombineFn with no components combines empty tuples to ()."""
    with TestPipeline() as p:
      result = (
          p
          | Create([(), (), ()])
          | beam.CombineGlobally(combine.TupleCombineFn()))
      assert_that(result, equal_to([()]))

  def test_tuple_combine_fn_batched_merge(self):
    """TupleCombineFn merges in batches, bounding live accumulators."""
    num_combine_fns = 10
    max_num_accumulators_in_memory = 30
    # Maximum number of accumulator tuples in memory - 1 for the merge result.
    merge_accumulators_batch_size = (
        max_num_accumulators_in_memory // num_combine_fns - 1)
    num_accumulator_tuples_to_merge = 20

    class CountedAccumulator:
      # Class-level tally of live accumulators; `oom` records whether the
      # tally ever exceeded the allowed maximum.
      count = 0
      oom = False

      def __init__(self):
        if CountedAccumulator.count > max_num_accumulators_in_memory:
          CountedAccumulator.oom = True
        else:
          CountedAccumulator.count += 1

    class CountedAccumulatorCombineFn(beam.CombineFn):
      def create_accumulator(self):
        return CountedAccumulator()

      def merge_accumulators(self, accumulators):
        # The merge result counts as one live accumulator; every merged
        # input is released.
        CountedAccumulator.count += 1
        for _ in accumulators:
          CountedAccumulator.count -= 1

    combine_fn = combine.TupleCombineFn(
        *[CountedAccumulatorCombineFn() for _ in range(num_combine_fns)],
        merge_accumulators_batch_size=merge_accumulators_batch_size)
    # A generator, so accumulator tuples are only created as they are consumed.
    combine_fn.merge_accumulators(
        combine_fn.create_accumulator()
        for _ in range(num_accumulator_tuples_to_merge))
    assert not CountedAccumulator.oom

  def test_to_list_and_to_dict1(self):
    """ToList gathers all elements, with and without defaults."""
    with TestPipeline() as pipeline:
      the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
      timestamp = 0
      pcoll = pipeline | 'start' >> Create(the_list)
      result = pcoll | 'to list' >> combine.ToList()

      # Now for global combines without default
      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
      result_windowed = (
          windowed
          | 'to list wo defaults' >> combine.ToList().without_defaults())

      def matcher(expected):
        def match(actual):
          equal_to(expected[0])(actual[0])

        return match

      assert_that(result, matcher([the_list]))
      assert_that(
          result_windowed, matcher([the_list]), label='to-list-wo-defaults')

  def test_to_list_and_to_dict2(self):
    """ToDict gathers key-value pairs into one dict, with/without defaults."""
    with TestPipeline() as pipeline:
      pairs = [(1, 2), (3, 4), (5, 6)]
      timestamp = 0
      pcoll = pipeline | 'start-pairs' >> Create(pairs)
      result = pcoll | 'to dict' >> combine.ToDict()

      # Now for global combines without default
      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
      result_windowed = (
          windowed
          | 'to dict wo defaults' >> combine.ToDict().without_defaults())

      def matcher():
        def match(actual):
          equal_to([1])([len(actual)])
          equal_to(pairs)(actual[0].items())

        return match

      assert_that(result, matcher())
      assert_that(result_windowed, matcher(), label='to-dict-wo-defaults')

  def test_to_set(self):
    """ToSet gathers the distinct elements, with and without defaults."""
    # The `with` block runs the pipeline on exit; without it, the
    # assertions below would never execute.
    with TestPipeline() as pipeline:
      the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
      timestamp = 0
      pcoll = pipeline | 'start' >> Create(the_list)
      result = pcoll | 'to set' >> combine.ToSet()

      # Now for global combines without default
      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
      result_windowed = (
          windowed
          | 'to set wo defaults' >> combine.ToSet().without_defaults())

      def matcher(expected):
        # `expected` is a list holding the single expected set, mirroring
        # the shape used by test_to_list_and_to_dict1 (a bare set cannot be
        # indexed).
        def match(actual):
          equal_to([1])([len(actual)])
          equal_to(expected[0])(actual[0])

        return match

      assert_that(result, matcher([set(the_list)]))
      assert_that(
          result_windowed, matcher([set(the_list)]), label='to-set-wo-defaults')

  def test_combine_globally_with_default(self):
    """CombineGlobally(sum) over no input emits sum's default, 0."""
    with TestPipeline() as p:
      assert_that(p | Create([]) | CombineGlobally(sum), equal_to([0]))

  def test_combine_globally_without_default(self):
    """CombineGlobally(...).without_defaults() over no input emits nothing."""
    with TestPipeline() as p:
      result = p | Create([]) | CombineGlobally(sum).without_defaults()
      assert_that(result, equal_to([]))

  def test_combine_globally_with_default_side_input(self):
    """A CombineGlobally singleton side input gets the default when empty."""
    class SideInputCombine(PTransform):
      def expand(self, pcoll):
        side = pcoll | CombineGlobally(sum).as_singleton_view()
        main = pcoll.pipeline | Create([None])
        return main | Map(lambda _, s: s, side)

    with TestPipeline() as p:
      result1 = p | 'i1' >> Create([]) | 'c1' >> SideInputCombine()
      result2 = p | 'i2' >> Create([1, 2, 3, 4]) | 'c2' >> SideInputCombine()
      assert_that(result1, equal_to([0]), label='r1')
      assert_that(result2, equal_to([10]), label='r2')

  def test_hot_key_fanout(self):
    """with_hot_key_fanout gives the same per-key means as a plain combine."""
    with TestPipeline() as p:
      result = (
          p
          | beam.Create(itertools.product(['hot', 'cold'], range(10)))
          | beam.CombinePerKey(combine.MeanCombineFn()).with_hot_key_fanout(
              lambda key: (key == 'hot') * 5))
      assert_that(result, equal_to([('hot', 4.5), ('cold', 4.5)]))

  def test_hot_key_fanout_sharded(self):
    """Hot-key fanout stays correct with sharded input and random fanout."""
    # Lots of elements with the same key with varying/no fanout.
    with TestPipeline() as p:
      elements = [(None, e) for e in range(1000)]
      random.shuffle(elements)
      shards = [
          p | "Shard%s" % shard >> beam.Create(elements[shard::20])
          for shard in range(20)
      ]
      result = (
          shards
          | beam.Flatten()
          | beam.CombinePerKey(combine.MeanCombineFn()).with_hot_key_fanout(
              lambda key: random.randrange(0, 5)))
      assert_that(result, equal_to([(None, 499.5)]))

  def test_global_fanout(self):
    """CombineGlobally(...).with_fanout gives the same global mean."""
    with TestPipeline() as p:
      result = (
          p
          | beam.Create(range(100))
          | beam.CombineGlobally(combine.MeanCombineFn()).with_fanout(11))
      assert_that(result, equal_to([49.5]))

  def test_combining_with_accumulation_mode_and_fanout(self):
    """Fanout combining emits accumulated panes under ACCUMULATING mode."""
    # PCollection will contain elements from 1 to 5.
    elements = list(range(1, 6))

    ts = TestStream().advance_watermark_to(0)
    for i in elements:
      ts.add_elements([i])
    ts.advance_watermark_to_infinity()

    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    with TestPipeline(options=options) as p:
      result = (
          p
          | ts
          | beam.WindowInto(
              GlobalWindows(),
              accumulation_mode=trigger.AccumulationMode.ACCUMULATING,
              trigger=AfterWatermark(early=AfterAll(AfterCount(1))))
          | beam.CombineGlobally(sum).without_defaults().with_fanout(2))

      def has_expected_values(actual):
        from hamcrest.core import assert_that as hamcrest_assert
        from hamcrest.library.collection import contains
        from hamcrest.library.collection import only_contains
        ordered = sorted(actual)
        # Early firings.
        hamcrest_assert(ordered[:4], contains(1, 3, 6, 10))
        # Different runners have different number of 15s, but there should
        # be at least one 15.
        hamcrest_assert(ordered[4:], only_contains(15))

      assert_that(result, has_expected_values)

  def test_combining_with_sliding_windows_and_fanout_raises_error(self):
    """with_fanout over sliding windows is rejected with a ValueError."""
    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    with self.assertRaises(ValueError):
      with TestPipeline(options=options) as p:
        _ = (
            p
            | beam.Create([
                window.TimestampedValue(0, Timestamp(seconds=1666707510)),
                window.TimestampedValue(1, Timestamp(seconds=1666707511)),
                window.TimestampedValue(2, Timestamp(seconds=1666707512)),
                window.TimestampedValue(3, Timestamp(seconds=1666707513)),
                window.TimestampedValue(5, Timestamp(seconds=1666707515)),
                window.TimestampedValue(6, Timestamp(seconds=1666707516)),
                window.TimestampedValue(7, Timestamp(seconds=1666707517)),
                window.TimestampedValue(8, Timestamp(seconds=1666707518))
            ])
            | beam.WindowInto(window.SlidingWindows(10, 5))
            | beam.CombineGlobally(beam.combiners.ToListCombineFn())
            .without_defaults().with_fanout(7))

  def test_MeanCombineFn_combine(self):
    """MeanCombineFn computes global and per-key means."""
    with TestPipeline() as p:
      inputs = (
          p
          | beam.Create([('a', 1), ('a', 1), ('a', 4), ('b', 1), ('b', 13)]))
      # The mean of all values regardless of key.
      global_mean = (
          inputs
          | beam.Values()
          | beam.CombineGlobally(combine.MeanCombineFn()))

      # The (key, mean) pairs for all keys.
      mean_per_key = (inputs | beam.CombinePerKey(combine.MeanCombineFn()))

      expected_mean_per_key = [('a', 2), ('b', 7)]
      assert_that(global_mean, equal_to([4]), label='global mean')
      assert_that(
          mean_per_key, equal_to(expected_mean_per_key), label='mean per key')

  def test_MeanCombineFn_combine_empty(self):
    """MeanCombineFn over no input: NaN globally, nothing per key."""
    with TestPipeline() as p:
      inputs = (p | beam.Create([]))

      # Compute the mean of all values in the PCollection, then format it
      # with str(). Since the PCollection is empty, the mean is
      # float('nan'), which str() renders as 'nan'.
      global_mean = (
          inputs
          | beam.Values()
          | beam.CombineGlobally(combine.MeanCombineFn())
          | beam.Map(str))

      mean_per_key = (inputs | beam.CombinePerKey(combine.MeanCombineFn()))

      # We can't compare one float('NaN') with another float('NaN'),
      # but we can compare one 'nan' string with another string.
      assert_that(global_mean, equal_to(['nan']), label='global mean')
      assert_that(mean_per_key, equal_to([]), label='mean per key')

  def test_sessions_combine(self):
    """Global and per-key sums are computed per session window."""
    with TestPipeline() as p:
      inputs = (
          p
          | beam.Create([('c', 1), ('c', 9), ('c', 12), ('d', 2), ('d', 4)])
          | beam.MapTuple(lambda k, v: window.TimestampedValue((k, v), v))
          | beam.WindowInto(window.Sessions(4)))

      global_sum = (
          inputs
          | beam.Values()
          | beam.CombineGlobally(sum).without_defaults())
      sum_per_key = inputs | beam.CombinePerKey(sum)

      # The first window has 3 elements: ('c', 1), ('d', 2), ('d', 4).
      # The second window has 2 elements: ('c', 9), ('c', 12).
      assert_that(global_sum, equal_to([7, 21]), label='global sum')
      assert_that(
          sum_per_key,
          equal_to([('c', 1), ('c', 21), ('d', 6)]),
          label='sum per key')

  def test_fixed_windows_combine(self):
    """Global and per-key sums are computed per fixed window."""
    with TestPipeline() as p:
      inputs = (
          p
          | beam.Create([('c', 1), ('c', 2), ('c', 10), ('d', 5), ('d', 8),
                         ('d', 9)])
          | beam.MapTuple(lambda k, v: window.TimestampedValue((k, v), v))
          | beam.WindowInto(window.FixedWindows(4)))

      global_sum = (
          inputs
          | beam.Values()
          | beam.CombineGlobally(sum).without_defaults())
      sum_per_key = inputs | beam.CombinePerKey(sum)

      # The first window has 2 elements: ('c', 1), ('c', 2).
      # The second window has 1 element: ('d', 5).
      # The third window has 3 elements: ('c', 10), ('d', 8), ('d', 9).
      assert_that(global_sum, equal_to([3, 5, 27]), label='global sum')
      assert_that(
          sum_per_key,
          equal_to([('c', 3), ('c', 10), ('d', 5), ('d', 17)]),
          label='sum per key')

  def test_custormized_counters_in_combine_fn(self):
    """Counter, distribution and gauge metrics from a customized CombineFn.

    Uses SortedConcatWithCounters. Metric checks are skipped when the runner
    does not report the corresponding metric kind.
    """
    p = TestPipeline()
    inputs = (
        p
        | beam.Create([('key1', 'a'), ('key1', 'ab'), ('key1', 'abc'),
                       ('key2', 'uvxy'), ('key2', 'uvxyz')]))

    # The result of concatenating all values regardless of key.
    global_concat = (
        inputs
        | beam.Values()
        | beam.CombineGlobally(SortedConcatWithCounters()))

    # The (key, concatenated_string) pairs for all keys.
    concat_per_key = (inputs | beam.CombinePerKey(SortedConcatWithCounters()))

    # Verify the concatenated strings are correct.
    expected_concat_per_key = [('key1', 'aaabbc'), ('key2', 'uuvvxxyyz')]
    assert_that(
        global_concat, equal_to(['aaabbcuuvvxxyyz']), label='global concat')
    assert_that(
        concat_per_key,
        equal_to(expected_concat_per_key),
        label='concat per key')

    result = p.run()
    result.wait_until_finish()

    # Verify the values of metrics are correct.
    word_counter_filter = MetricsFilter().with_name('word_counter')
    query_result = result.metrics().query(word_counter_filter)
    if query_result['counters']:
      word_counter = query_result['counters'][0]
      self.assertEqual(word_counter.result, 5)

    word_lengths_filter = MetricsFilter().with_name('word_lengths')
    query_result = result.metrics().query(word_lengths_filter)
    if query_result['counters']:
      word_lengths = query_result['counters'][0]
      self.assertEqual(word_lengths.result, 15)

    word_len_dist_filter = MetricsFilter().with_name('word_len_dist')
    query_result = result.metrics().query(word_len_dist_filter)
    if query_result['distributions']:
      word_len_dist = query_result['distributions'][0]
      self.assertEqual(word_len_dist.result.mean, 3)

    last_word_len_filter = MetricsFilter().with_name('last_word_len')
    query_result = result.metrics().query(last_word_len_filter)
    if query_result['gauges']:
      last_word_len = query_result['gauges'][0]
      # Element processing order is not deterministic, so any word length
      # may be the last one seen.
      self.assertIn(last_word_len.result.value, [1, 2, 3, 4, 5])

  def test_custormized_counters_in_combine_fn_empty(self):
    """Metrics from SortedConcatWithCounters when the input is empty.

    Metric checks are skipped when the runner does not report the
    corresponding metric kind; the gauge must never have been set.
    """
    p = TestPipeline()
    inputs = p | beam.Create([])

    # The result of concatenating all values regardless of key.
    global_concat = (
        inputs
        | beam.Values()
        | beam.CombineGlobally(SortedConcatWithCounters()))

    # The (key, concatenated_string) pairs for all keys.
    concat_per_key = (inputs | beam.CombinePerKey(SortedConcatWithCounters()))

    # Verify the concatenated strings are correct.
    assert_that(global_concat, equal_to(['']), label='global concat')
    assert_that(concat_per_key, equal_to([]), label='concat per key')

    result = p.run()
    result.wait_until_finish()

    # Verify the values of metrics are correct.
    word_counter_filter = MetricsFilter().with_name('word_counter')
    query_result = result.metrics().query(word_counter_filter)
    if query_result['counters']:
      word_counter = query_result['counters'][0]
      self.assertEqual(word_counter.result, 0)

    word_lengths_filter = MetricsFilter().with_name('word_lengths')
    query_result = result.metrics().query(word_lengths_filter)
    if query_result['counters']:
      word_lengths = query_result['counters'][0]
      self.assertEqual(word_lengths.result, 0)

    word_len_dist_filter = MetricsFilter().with_name('word_len_dist')
    query_result = result.metrics().query(word_len_dist_filter)
    if query_result['distributions']:
      word_len_dist = query_result['distributions'][0]
      self.assertEqual(word_len_dist.result.count, 0)

    last_word_len_filter = MetricsFilter().with_name('last_word_len')
    query_result = result.metrics().query(last_word_len_filter)

    # No element has ever been recorded.
    self.assertFalse(query_result['gauges'])


class LatestTest(unittest.TestCase):
  """Tests for the Latest.Globally and Latest.PerKey transforms."""
  def test_globally(self):
    """Latest.Globally picks the value with the greatest timestamp."""
    values = [
        window.TimestampedValue(3, 100),
        window.TimestampedValue(1, 200),
        window.TimestampedValue(2, 300)
    ]
    with TestPipeline() as p:
      # Map(lambda x: x) PTransform is added after Create here, because when
      # a PCollection of TimestampedValues is created with Create PTransform,
      # the timestamps are not assigned to it. Adding a Map forces the
      # PCollection to go through a DoFn so that the PCollection consists of
      # the elements with timestamps assigned to them instead of a PCollection
      # of TimestampedValue(element, timestamp).
      pcoll = p | Create(values) | Map(lambda x: x)
      latest = pcoll | combine.Latest.Globally()
      assert_that(latest, equal_to([2]))

      # Now for global combines without default
      windowed = pcoll | 'window' >> WindowInto(FixedWindows(180))
      result_windowed = (
          windowed
          |
          'latest wo defaults' >> combine.Latest.Globally().without_defaults())

      assert_that(result_windowed, equal_to([3, 2]), label='latest-wo-defaults')

  def test_globally_empty(self):
    """Latest.Globally over no input emits None."""
    values = []
    with TestPipeline() as p:
      pc = p | Create(values) | Map(lambda x: x)
      latest = pc | combine.Latest.Globally()
      assert_that(latest, equal_to([None]))

  def test_per_key(self):
    """Latest.PerKey picks, per key, the value with the greatest timestamp."""
    values = [
        window.TimestampedValue(('a', 1), 300),
        window.TimestampedValue(('b', 3), 100),
        window.TimestampedValue(('a', 2), 200)
    ]
    with TestPipeline() as p:
      pc = p | Create(values) | Map(lambda x: x)
      latest = pc | combine.Latest.PerKey()
      assert_that(latest, equal_to([('a', 1), ('b', 3)]))

  def test_per_key_empty(self):
    """Latest.PerKey over no input emits nothing."""
    values = []
    with TestPipeline() as p:
      pc = p | Create(values) | Map(lambda x: x)
      latest = pc | combine.Latest.PerKey()
      assert_that(latest, equal_to([]))


class LatestCombineFnTest(unittest.TestCase):
  """Unit tests for the individual methods of LatestCombineFn."""
  def setUp(self):
    self.fn = combine.LatestCombineFn()

  def test_create_accumulator(self):
    accumulator = self.fn.create_accumulator()
    self.assertEqual(accumulator, (None, window.MIN_TIMESTAMP))

  def test_add_input(self):
    accumulator = self.fn.create_accumulator()
    element = (1, 100)
    new_accumulator = self.fn.add_input(accumulator, element)
    self.assertEqual(new_accumulator, (1, 100))

  def test_merge_accumulators(self):
    accumulators = [(2, 400), (5, 100), (9, 200)]
    merged_accumulator = self.fn.merge_accumulators(accumulators)
    self.assertEqual(merged_accumulator, (2, 400))

  def test_extract_output(self):
    accumulator = (1, 100)
    output = self.fn.extract_output(accumulator)
    self.assertEqual(output, 1)

  def test_with_input_types_decorator_violation(self):
    """Inputs that are not (value, timestamp) pairs raise TypeCheckError."""
    l_int = [1, 2, 3]
    l_dict = [{'a': 3}, {'g': 5}, {'r': 8}]
    l_3_tuple = [(12, 31, 41), (12, 34, 34), (84, 92, 74)]

    with self.assertRaises(TypeCheckError):
      with TestPipeline() as p:
        pc = p | Create(l_int)
        _ = pc | beam.CombineGlobally(self.fn)

    with self.assertRaises(TypeCheckError):
      with TestPipeline() as p:
        pc = p | Create(l_dict)
        _ = pc | beam.CombineGlobally(self.fn)

    with self.assertRaises(TypeCheckError):
      with TestPipeline() as p:
        pc = p | Create(l_3_tuple)
        _ = pc | beam.CombineGlobally(self.fn)


@pytest.mark.it_validatesrunner
class CombineValuesTest(unittest.TestCase):
  """Tests for CombineValues applied directly after a GroupByKey."""
  def test_gbk_immediately_followed_by_combine(self):
    def merge(vals):
      return "".join(vals)

    with TestPipeline() as p:
      result = (
          p
          | Create([("key1", "foo"), ("key2", "bar"), ("key1", "foo")],
                   reshuffle=False)
          | beam.GroupByKey()
          | beam.CombineValues(merge)
          | beam.MapTuple(lambda k, v: '{}: {}'.format(k, v)))

      assert_that(result, equal_to(['key1: foofoo', 'key2: bar']))


#
# Test cases for streaming.
#
@pytest.mark.it_validatesrunner
class TimestampCombinerTest(unittest.TestCase):
  """Streaming tests for the timestamp_combiner option of beam.WindowInto.

  Each test streams ('k', 100) at t=2 and ('k', 400) at t=7 into 10-second
  fixed windows, sums the values per key, and checks which timestamp the
  combined ('k', 500) output is assigned in window [0, 10).
  """
  def _assert_combined_timestamp(self, timestamp_combiner, expected_timestamp):
    """Runs the shared streaming pipeline and checks the output timestamp.

    Args:
      timestamp_combiner: TimestampCombiner passed to beam.WindowInto.
      expected_timestamp: Timestamp expected on the combined ('k', 500)
        element in window [0, 10).
    """
    options = PipelineOptions(streaming=True)
    with TestPipeline(options=options) as p:
      # TestStream's builder methods return the stream, so one call per line
      # yields the same transform as the original nested chain.
      stream = (
          TestStream()
          .add_elements([window.TimestampedValue(('k', 100), 2)])
          .add_elements([window.TimestampedValue(('k', 400), 7)])
          .advance_watermark_to_infinity())
      result = (
          p
          | stream
          | beam.WindowInto(
              window.FixedWindows(10), timestamp_combiner=timestamp_combiner)
          | beam.CombinePerKey(sum))

      # Pair every output element with the timestamp it was emitted at.
      records = (
          result
          | beam.Map(lambda e, ts=beam.DoFn.TimestampParam: (e, ts)))

      expected_window_to_elements = {
          window.IntervalWindow(0, 10): [
              (('k', 500), expected_timestamp),
          ],
      }

      assert_that(
          records,
          equal_to_per_window(expected_window_to_elements),
          use_global_window=False,
          label='assert per window')

  def test_combiner_earliest(self):
    """Test TimestampCombiner with EARLIEST."""
    # Both values for key 'k' are grouped together; the combined output takes
    # the EARLIEST input timestamp.
    self._assert_combined_timestamp(
        TimestampCombiner.OUTPUT_AT_EARLIEST, Timestamp(2))

  def test_combiner_latest(self):
    """Test TimestampCombiner with LATEST."""
    # Both values for key 'k' are grouped together; the combined output takes
    # the LATEST input timestamp.
    self._assert_combined_timestamp(
        TimestampCombiner.OUTPUT_AT_LATEST, Timestamp(7))


if __name__ == '__main__':
  unittest.main()