github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/trigger_test.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """Unit tests for the triggering classes.""" 19 20 # pytype: skip-file 21 22 import collections 23 import json 24 import os.path 25 import pickle 26 import random 27 import unittest 28 29 import yaml 30 31 import apache_beam as beam 32 from apache_beam import coders 33 from apache_beam.options.pipeline_options import PipelineOptions 34 from apache_beam.options.pipeline_options import StandardOptions 35 from apache_beam.options.pipeline_options import TypeOptions 36 from apache_beam.portability import common_urns 37 from apache_beam.runners import pipeline_context 38 from apache_beam.runners.direct.clock import TestClock 39 from apache_beam.testing.test_pipeline import TestPipeline 40 from apache_beam.testing.test_stream import TestStream 41 from apache_beam.testing.util import assert_that 42 from apache_beam.testing.util import equal_to 43 from apache_beam.transforms import WindowInto 44 from apache_beam.transforms import ptransform 45 from apache_beam.transforms import trigger 46 from apache_beam.transforms.core import Windowing 47 from apache_beam.transforms.trigger import AccumulationMode 48 from apache_beam.transforms.trigger import AfterAll 49 from apache_beam.transforms.trigger import AfterAny 50 from apache_beam.transforms.trigger import AfterCount 51 from apache_beam.transforms.trigger import AfterEach 52 from apache_beam.transforms.trigger import AfterProcessingTime 53 from apache_beam.transforms.trigger import AfterWatermark 54 from apache_beam.transforms.trigger import Always 55 from apache_beam.transforms.trigger import DataLossReason 56 from apache_beam.transforms.trigger import DefaultTrigger 57 from apache_beam.transforms.trigger import GeneralTriggerDriver 58 from apache_beam.transforms.trigger import InMemoryUnmergedState 59 from apache_beam.transforms.trigger import Repeatedly 60 from apache_beam.transforms.trigger import TriggerFn 61 from apache_beam.transforms.trigger import _Never 62 from apache_beam.transforms.window import FixedWindows 63 from apache_beam.transforms.window import GlobalWindows 64 from apache_beam.transforms.window import IntervalWindow 65 from apache_beam.transforms.window import Sessions 66 from apache_beam.transforms.window import TimestampCombiner 67 from apache_beam.transforms.window import TimestampedValue 68 from apache_beam.transforms.window import WindowedValue 69 from apache_beam.transforms.window import WindowFn 70 from apache_beam.utils.timestamp import MAX_TIMESTAMP 71 from apache_beam.utils.timestamp import MIN_TIMESTAMP 72 from apache_beam.utils.timestamp import Duration 73 from apache_beam.utils.windowed_value import PaneInfoTiming 74 75 76 class CustomTimestampingFixedWindowsWindowFn(FixedWindows): 77 """WindowFn for testing custom timestamping.""" 78 def get_transformed_output_time(self, unused_window, input_timestamp): 79 return input_timestamp + 100 80 81 82 class TriggerTest(unittest.TestCase): 83 def run_trigger_simple( 84 self, 85 window_fn, 86 trigger_fn, 87 accumulation_mode, 88 timestamped_data, 89 expected_panes, 90 *groupings, 91 **kwargs): 92 # Groupings is a list of integers indicating the (uniform) size of bundles 93 # to try. For example, if timestamped_data has elements [a, b, c, d, e] 94 # then groupings=(5, 2) would first run the test with everything in the same 95 # bundle, and then re-run the test with bundling [a, b], [c, d], [e]. 96 # A negative value will reverse the order, e.g. -2 would result in bundles 97 # [e, d], [c, b], [a]. This is useful for deterministic triggers in testing 98 # that the output is not a function of ordering or bundling. 99 # If empty, defaults to bundles of size 1 in the given order. 100 late_data = kwargs.pop('late_data', []) 101 assert not kwargs 102 103 def bundle_data(data, size): 104 if size < 0: 105 data = list(data)[::-1] 106 size = -size 107 bundle = [] 108 for timestamp, elem in data: 109 windows = window_fn.assign(WindowFn.AssignContext(timestamp, elem)) 110 bundle.append(WindowedValue(elem, timestamp, windows)) 111 if len(bundle) == size: 112 yield bundle 113 bundle = [] 114 if bundle: 115 yield bundle 116 117 if not groupings: 118 groupings = [1] 119 for group_by in groupings: 120 self.run_trigger( 121 window_fn, 122 trigger_fn, 123 accumulation_mode, 124 bundle_data(timestamped_data, group_by), 125 bundle_data(late_data, group_by), 126 expected_panes) 127 128 def run_trigger( 129 self, 130 window_fn, 131 trigger_fn, 132 accumulation_mode, 133 bundles, 134 late_bundles, 135 expected_panes): 136 actual_panes = collections.defaultdict(list) 137 allowed_lateness = Duration( 138 micros=int(common_urns.constants.MAX_TIMESTAMP_MILLIS.constant) * 1000) 139 driver = GeneralTriggerDriver( 140 Windowing( 141 window_fn, 142 trigger_fn, 143 accumulation_mode, 144 allowed_lateness=allowed_lateness), 145 TestClock()) 146 state = InMemoryUnmergedState() 147 148 for bundle in bundles: 149 for wvalue in driver.process_elements(state, 150 bundle, 151 MIN_TIMESTAMP, 152 MIN_TIMESTAMP): 153 window, = wvalue.windows 154 self.assertEqual(window.max_timestamp(), wvalue.timestamp) 155 actual_panes[window].append(set(wvalue.value)) 156 157 while state.timers: 158 for timer_window, (name, time_domain, timestamp, 159 _) in state.get_and_clear_timers(): 160 for wvalue in driver.process_timer(timer_window, 161 name, 162 time_domain, 163 timestamp, 164 state, 165 MIN_TIMESTAMP): 166 window, = wvalue.windows 167 self.assertEqual(window.max_timestamp(), wvalue.timestamp) 168 actual_panes[window].append(set(wvalue.value)) 169 170 for bundle in late_bundles: 171 for wvalue in driver.process_elements(state, 172 bundle, 173 MAX_TIMESTAMP, 174 MAX_TIMESTAMP): 175 window, = wvalue.windows 176 self.assertEqual(window.max_timestamp(), wvalue.timestamp) 177 actual_panes[window].append(set(wvalue.value)) 178 179 while state.timers: 180 for timer_window, (name, time_domain, timestamp, 181 _) in state.get_and_clear_timers(): 182 for wvalue in driver.process_timer(timer_window, 183 name, 184 time_domain, 185 timestamp, 186 state, 187 MAX_TIMESTAMP): 188 window, = wvalue.windows 189 self.assertEqual(window.max_timestamp(), wvalue.timestamp) 190 actual_panes[window].append(set(wvalue.value)) 191 192 self.assertEqual(expected_panes, actual_panes) 193 194 def test_fixed_watermark(self): 195 self.run_trigger_simple( 196 FixedWindows(10), # pyformat break 197 AfterWatermark(), 198 AccumulationMode.ACCUMULATING, 199 [(1, 'a'), (2, 'b'), (13, 'c')], 200 {IntervalWindow(0, 10): [set('ab')], 201 IntervalWindow(10, 20): [set('c')]}, 202 1, 203 2, 204 3, 205 -3, 206 -2, 207 -1) 208 209 def test_fixed_watermark_with_early(self): 210 self.run_trigger_simple( 211 FixedWindows(10), # pyformat break 212 AfterWatermark(early=AfterCount(2)), 213 AccumulationMode.ACCUMULATING, 214 [(1, 'a'), (2, 'b'), (3, 'c')], 215 {IntervalWindow(0, 10): [set('ab'), set('abc')]}, 216 2) 217 self.run_trigger_simple( 218 FixedWindows(10), # pyformat break 219 AfterWatermark(early=AfterCount(2)), 220 AccumulationMode.ACCUMULATING, 221 [(1, 'a'), (2, 'b'), (3, 'c')], 222 {IntervalWindow(0, 10): [set('abc'), set('abc')]}, 223 3) 224 225 def test_fixed_watermark_with_early_late(self): 226 self.run_trigger_simple( 227 FixedWindows(100), # pyformat break 228 AfterWatermark(early=AfterCount(3), 229 late=AfterCount(2)), 230 AccumulationMode.DISCARDING, 231 zip(range(9), 'abcdefghi'), 232 {IntervalWindow(0, 100): [ 233 set('abcd'), set('efgh'), # early 234 set('i'), # on time 235 set('vw'), set('xy') # late 236 ]}, 237 2, 238 late_data=zip(range(5), 'vwxyz')) 239 240 def test_sessions_watermark_with_early_late(self): 241 self.run_trigger_simple( 242 Sessions(10), # pyformat break 243 AfterWatermark(early=AfterCount(2), 244 late=AfterCount(1)), 245 AccumulationMode.ACCUMULATING, 246 [(1, 'a'), (15, 'b'), (7, 'c'), (30, 'd')], 247 { 248 IntervalWindow(1, 25): [ 249 set('abc'), # early 250 set('abc'), # on time 251 set('abcxy') # late 252 ], 253 IntervalWindow(30, 40): [ 254 set('d'), # on time 255 ], 256 IntervalWindow(1, 40): [ 257 set('abcdxyz') # late 258 ], 259 }, 260 2, 261 late_data=[(1, 'x'), (2, 'y'), (21, 'z')]) 262 263 def test_fixed_after_count(self): 264 self.run_trigger_simple( 265 FixedWindows(10), # pyformat break 266 AfterCount(2), 267 AccumulationMode.ACCUMULATING, 268 [(1, 'a'), (2, 'b'), (3, 'c'), (11, 'z')], 269 {IntervalWindow(0, 10): [set('ab')]}, 270 1, 271 2) 272 self.run_trigger_simple( 273 FixedWindows(10), # pyformat break 274 AfterCount(2), 275 AccumulationMode.ACCUMULATING, 276 [(1, 'a'), (2, 'b'), (3, 'c'), (11, 'z')], 277 {IntervalWindow(0, 10): [set('abc')]}, 278 3, 279 4) 280 281 def test_fixed_after_first(self): 282 self.run_trigger_simple( 283 FixedWindows(10), # pyformat break 284 AfterAny(AfterCount(2), AfterWatermark()), 285 AccumulationMode.ACCUMULATING, 286 [(1, 'a'), (2, 'b'), (3, 'c')], 287 {IntervalWindow(0, 10): [set('ab')]}, 288 1, 289 2) 290 self.run_trigger_simple( 291 FixedWindows(10), # pyformat break 292 AfterAny(AfterCount(5), AfterWatermark()), 293 AccumulationMode.ACCUMULATING, 294 [(1, 'a'), (2, 'b'), (3, 'c')], 295 {IntervalWindow(0, 10): [set('abc')]}, 296 1, 297 2, 298 late_data=[(1, 'x'), (2, 'y'), (3, 'z')]) 299 300 def test_repeatedly_after_first(self): 301 self.run_trigger_simple( 302 FixedWindows(100), # pyformat break 303 Repeatedly(AfterAny(AfterCount(3), AfterWatermark())), 304 AccumulationMode.ACCUMULATING, 305 zip(range(7), 'abcdefg'), 306 {IntervalWindow(0, 100): [ 307 set('abc'), 308 set('abcdef'), 309 set('abcdefg'), 310 set('abcdefgx'), 311 set('abcdefgxy'), 312 set('abcdefgxyz')]}, 313 1, 314 late_data=zip(range(3), 'xyz')) 315 316 def test_sessions_after_all(self): 317 self.run_trigger_simple( 318 Sessions(10), # pyformat break 319 AfterAll(AfterCount(2), AfterWatermark()), 320 AccumulationMode.ACCUMULATING, 321 [(1, 'a'), (2, 'b'), (3, 'c')], 322 {IntervalWindow(1, 13): [set('abc')]}, 323 1, 324 2) 325 self.run_trigger_simple( 326 Sessions(10), # pyformat break 327 AfterAll(AfterCount(5), AfterWatermark()), 328 AccumulationMode.ACCUMULATING, 329 [(1, 'a'), (2, 'b'), (3, 'c')], 330 {IntervalWindow(1, 13): [set('abcxy')]}, 331 1, 332 2, 333 late_data=[(1, 'x'), (2, 'y'), (3, 'z')]) 334 335 def test_sessions_default(self): 336 self.run_trigger_simple( 337 Sessions(10), # pyformat break 338 DefaultTrigger(), 339 AccumulationMode.ACCUMULATING, 340 [(1, 'a'), (2, 'b')], 341 {IntervalWindow(1, 12): [set('ab')]}, 342 1, 343 2, 344 -2, 345 -1) 346 347 self.run_trigger_simple( 348 Sessions(10), # pyformat break 349 AfterWatermark(), 350 AccumulationMode.ACCUMULATING, 351 [(1, 'a'), (2, 'b'), (15, 'c'), (16, 'd'), (30, 'z'), (9, 'e'), 352 (10, 'f'), (30, 'y')], 353 {IntervalWindow(1, 26): [set('abcdef')], 354 IntervalWindow(30, 40): [set('yz')]}, 355 1, 356 2, 357 3, 358 4, 359 5, 360 6, 361 -4, 362 -2, 363 -1) 364 365 def test_sessions_watermark(self): 366 self.run_trigger_simple( 367 Sessions(10), # pyformat break 368 AfterWatermark(), 369 AccumulationMode.ACCUMULATING, 370 [(1, 'a'), (2, 'b')], 371 {IntervalWindow(1, 12): [set('ab')]}, 372 1, 373 2, 374 -2, 375 -1) 376 377 def test_sessions_after_count(self): 378 self.run_trigger_simple( 379 Sessions(10), # pyformat break 380 AfterCount(2), 381 AccumulationMode.ACCUMULATING, 382 [(1, 'a'), (15, 'b'), (6, 'c'), (30, 's'), (31, 't'), (50, 'z'), 383 (50, 'y')], 384 {IntervalWindow(1, 25): [set('abc')], 385 IntervalWindow(30, 41): [set('st')], 386 IntervalWindow(50, 60): [set('yz')]}, 387 1, 388 2, 389 3) 390 391 def test_sessions_repeatedly_after_count(self): 392 self.run_trigger_simple( 393 Sessions(10), # pyformat break 394 Repeatedly(AfterCount(2)), 395 AccumulationMode.ACCUMULATING, 396 [(1, 'a'), (15, 'b'), (6, 'c'), (2, 'd'), (7, 'e')], 397 {IntervalWindow(1, 25): [set('abc'), set('abcde')]}, 398 1, 399 3) 400 self.run_trigger_simple( 401 Sessions(10), # pyformat break 402 Repeatedly(AfterCount(2)), 403 AccumulationMode.DISCARDING, 404 [(1, 'a'), (15, 'b'), (6, 'c'), (2, 'd'), (7, 'e')], 405 {IntervalWindow(1, 25): [set('abc'), set('de')]}, 406 1, 407 3) 408 409 def test_sessions_after_each(self): 410 self.run_trigger_simple( 411 Sessions(10), # pyformat break 412 AfterEach(AfterCount(2), AfterCount(3)), 413 AccumulationMode.ACCUMULATING, 414 zip(range(10), 'abcdefghij'), 415 {IntervalWindow(0, 11): [set('ab')], 416 IntervalWindow(0, 15): [set('abcdef')]}, 417 2) 418 419 self.run_trigger_simple( 420 Sessions(10), # pyformat break 421 Repeatedly(AfterEach(AfterCount(2), AfterCount(3))), 422 AccumulationMode.ACCUMULATING, 423 zip(range(10), 'abcdefghij'), 424 {IntervalWindow(0, 11): [set('ab')], 425 IntervalWindow(0, 15): [set('abcdef')], 426 IntervalWindow(0, 17): [set('abcdefgh')]}, 427 2) 428 429 def test_picklable_output(self): 430 global_window = (trigger.GlobalWindow(), ) 431 driver = trigger.BatchGlobalTriggerDriver() 432 unpicklable = (WindowedValue(k, 0, global_window) for k in range(10)) 433 with self.assertRaises(TypeError): 434 pickle.dumps(unpicklable) 435 for unwindowed in driver.process_elements(None, unpicklable, None, None): 436 self.assertEqual( 437 pickle.loads(pickle.dumps(unwindowed)).value, list(range(10))) 438 439 440 class MayLoseDataTest(unittest.TestCase): 441 def _test(self, trigger, lateness, expected): 442 windowing = WindowInto( 443 GlobalWindows(), 444 trigger=trigger, 445 accumulation_mode=AccumulationMode.ACCUMULATING, 446 allowed_lateness=lateness).windowing 447 self.assertEqual(trigger.may_lose_data(windowing), expected) 448 449 def test_default_trigger(self): 450 self._test(DefaultTrigger(), 0, DataLossReason.NO_POTENTIAL_LOSS) 451 452 def test_after_processing(self): 453 self._test(AfterProcessingTime(42), 0, DataLossReason.MAY_FINISH) 454 455 def test_always(self): 456 self._test(Always(), 0, DataLossReason.NO_POTENTIAL_LOSS) 457 458 def test_never(self): 459 self._test(_Never(), 0, DataLossReason.NO_POTENTIAL_LOSS) 460 461 def test_after_watermark_no_allowed_lateness(self): 462 self._test(AfterWatermark(), 0, DataLossReason.NO_POTENTIAL_LOSS) 463 464 def test_after_watermark_no_late_trigger(self): 465 self._test(AfterWatermark(), 60, DataLossReason.MAY_FINISH) 466 467 def test_after_watermark_no_allowed_lateness_safe_late(self): 468 self._test( 469 AfterWatermark(late=DefaultTrigger()), 470 0, 471 DataLossReason.NO_POTENTIAL_LOSS) 472 473 def test_after_watermark_allowed_lateness_safe_late(self): 474 self._test( 475 AfterWatermark(late=DefaultTrigger()), 476 60, 477 DataLossReason.NO_POTENTIAL_LOSS) 478 479 def test_after_count(self): 480 self._test(AfterCount(42), 0, DataLossReason.MAY_FINISH) 481 482 def test_repeatedly_safe_underlying(self): 483 self._test( 484 Repeatedly(DefaultTrigger()), 0, DataLossReason.NO_POTENTIAL_LOSS) 485 486 def test_repeatedly_unsafe_underlying(self): 487 self._test(Repeatedly(AfterCount(42)), 0, DataLossReason.NO_POTENTIAL_LOSS) 488 489 def test_after_any_one_may_finish(self): 490 self._test( 491 AfterAny(AfterCount(42), DefaultTrigger()), 492 0, 493 DataLossReason.MAY_FINISH) 494 495 def test_after_any_all_safe(self): 496 self._test( 497 AfterAny(Repeatedly(AfterCount(42)), DefaultTrigger()), 498 0, 499 DataLossReason.NO_POTENTIAL_LOSS) 500 501 def test_after_all_some_may_finish(self): 502 self._test( 503 AfterAll(AfterCount(1), DefaultTrigger()), 504 0, 505 DataLossReason.NO_POTENTIAL_LOSS) 506 507 def test_afer_all_all_may_finish(self): 508 self._test( 509 AfterAll(AfterCount(42), AfterProcessingTime(42)), 510 0, 511 DataLossReason.MAY_FINISH) 512 513 def test_after_each_at_least_one_safe(self): 514 self._test( 515 AfterEach(AfterCount(1), DefaultTrigger(), AfterCount(2)), 516 0, 517 DataLossReason.NO_POTENTIAL_LOSS) 518 519 def test_after_each_all_may_finish(self): 520 self._test( 521 AfterEach(AfterCount(1), AfterCount(2), AfterCount(3)), 522 0, 523 DataLossReason.MAY_FINISH) 524 525 526 class RunnerApiTest(unittest.TestCase): 527 def test_trigger_encoding(self): 528 for trigger_fn in (DefaultTrigger(), 529 AfterAll(AfterCount(1), AfterCount(10)), 530 AfterAny(AfterCount(10), AfterCount(100)), 531 AfterWatermark(early=AfterCount(1000)), 532 AfterWatermark(early=AfterCount(1000), 533 late=AfterCount(1)), 534 Repeatedly(AfterCount(100)), 535 trigger.OrFinally(AfterCount(3), AfterCount(10))): 536 context = pipeline_context.PipelineContext() 537 self.assertEqual( 538 trigger_fn, 539 TriggerFn.from_runner_api(trigger_fn.to_runner_api(context), context)) 540 541 542 class TriggerPipelineTest(unittest.TestCase): 543 def test_after_processing_time(self): 544 test_options = PipelineOptions( 545 flags=['--allow_unsafe_triggers', '--streaming']) 546 with TestPipeline(options=test_options) as p: 547 548 total_elements_in_trigger = 4 549 processing_time_delay = 2 550 window_size = 10 551 552 # yapf: disable 553 test_stream = TestStream() 554 for i in range(total_elements_in_trigger): 555 (test_stream 556 .advance_processing_time( 557 processing_time_delay / total_elements_in_trigger) 558 .add_elements([('key', i)]) 559 ) 560 561 test_stream.advance_processing_time(processing_time_delay) 562 563 # Add dropped elements 564 (test_stream 565 .advance_processing_time(0.1) 566 .add_elements([('key', "dropped-1")]) 567 .advance_processing_time(0.1) 568 .add_elements([('key', "dropped-2")]) 569 ) 570 571 (test_stream 572 .advance_processing_time(processing_time_delay) 573 .advance_watermark_to_infinity() 574 ) 575 # yapf: enable 576 577 results = ( 578 p 579 | test_stream 580 | beam.WindowInto( 581 FixedWindows(window_size), 582 trigger=AfterProcessingTime(processing_time_delay), 583 accumulation_mode=AccumulationMode.DISCARDING) 584 | beam.GroupByKey() 585 | beam.Map(lambda x: x[1])) 586 587 assert_that(results, equal_to([list(range(total_elements_in_trigger))])) 588 589 def test_repeatedly_after_processing_time(self): 590 test_options = PipelineOptions(flags=['--streaming']) 591 with TestPipeline(options=test_options) as p: 592 total_elements = 7 593 processing_time_delay = 2 594 window_size = 10 595 # yapf: disable 596 test_stream = TestStream() 597 for i in range(total_elements): 598 (test_stream 599 .advance_processing_time(processing_time_delay - 0.01) 600 .add_elements([('key', i)]) 601 ) 602 603 (test_stream 604 .advance_processing_time(processing_time_delay) 605 .advance_watermark_to_infinity() 606 ) 607 # yapf: enable 608 609 results = ( 610 p 611 | test_stream 612 | beam.WindowInto( 613 FixedWindows(window_size), 614 trigger=Repeatedly(AfterProcessingTime(processing_time_delay)), 615 accumulation_mode=AccumulationMode.DISCARDING) 616 | beam.GroupByKey() 617 | beam.Map(lambda x: x[1])) 618 619 expected = [[i, i + 1] 620 for i in range(total_elements - total_elements % 2) 621 if i % 2 == 0] 622 expected += [] if total_elements % 2 == 0 else [[total_elements - 1]] 623 624 assert_that(results, equal_to(expected)) 625 626 def test_after_count(self): 627 test_options = PipelineOptions(flags=['--allow_unsafe_triggers']) 628 with TestPipeline(options=test_options) as p: 629 630 def construct_timestamped(k, t): 631 return TimestampedValue((k, t), t) 632 633 def format_result(k, vs): 634 return ('%s-%s' % (k, len(list(vs))), set(vs)) 635 636 result = ( 637 p 638 | beam.Create([1, 2, 3, 4, 5, 10, 11]) 639 | beam.FlatMap(lambda t: [('A', t), ('B', t + 5)]) 640 | beam.MapTuple(construct_timestamped) 641 | beam.WindowInto( 642 FixedWindows(10), 643 trigger=AfterCount(3), 644 accumulation_mode=AccumulationMode.DISCARDING) 645 | beam.GroupByKey() 646 | beam.MapTuple(format_result)) 647 assert_that( 648 result, 649 equal_to( 650 list({ 651 'A-5': {1, 2, 3, 4, 5}, 652 # A-10, A-11 never emitted due to AfterCount(3) never firing. 653 'B-4': {6, 7, 8, 9}, 654 'B-3': {10, 15, 16}, 655 }.items()))) 656 657 def test_after_count_streaming(self): 658 test_options = PipelineOptions( 659 flags=['--allow_unsafe_triggers', '--streaming']) 660 with TestPipeline(options=test_options) as p: 661 # yapf: disable 662 test_stream = ( 663 TestStream() 664 .advance_watermark_to(0) 665 .add_elements([('A', 1), ('A', 2), ('A', 3)]) 666 .add_elements([('A', 4), ('A', 5), ('A', 6)]) 667 .add_elements([('B', 1), ('B', 2), ('B', 3)]) 668 .advance_watermark_to_infinity()) 669 # yapf: enable 670 671 results = ( 672 p 673 | test_stream 674 | beam.WindowInto( 675 FixedWindows(10), 676 trigger=AfterCount(3), 677 accumulation_mode=AccumulationMode.ACCUMULATING) 678 | beam.GroupByKey()) 679 680 assert_that( 681 results, 682 equal_to(list({ 683 'A': [1, 2, 3], # 4 - 6 discarded because trigger finished 684 'B': [1, 2, 3]}.items()))) 685 686 def test_always(self): 687 with TestPipeline() as p: 688 689 def construct_timestamped(k, t): 690 return TimestampedValue((k, t), t) 691 692 def format_result(k, vs): 693 return ('%s-%s' % (k, len(list(vs))), set(vs)) 694 695 result = ( 696 p 697 | beam.Create([1, 1, 2, 3, 4, 5, 10, 11]) 698 | beam.FlatMap(lambda t: [('A', t), ('B', t + 5)]) 699 | beam.MapTuple(construct_timestamped) 700 | beam.WindowInto( 701 FixedWindows(10), 702 trigger=Always(), 703 accumulation_mode=AccumulationMode.DISCARDING) 704 | beam.GroupByKey() 705 | beam.MapTuple(format_result)) 706 assert_that( 707 result, 708 equal_to( 709 list({ 710 'A-2': {10, 11}, 711 # Elements out of windows are also emitted. 712 'A-6': {1, 2, 3, 4, 5}, 713 # A,1 is emitted twice. 714 'B-5': {6, 7, 8, 9}, 715 # B,6 is emitted twice. 716 'B-3': {10, 15, 16}, 717 }.items()))) 718 719 def test_never(self): 720 with TestPipeline() as p: 721 722 def construct_timestamped(k, t): 723 return TimestampedValue((k, t), t) 724 725 def format_result(k, vs): 726 return ('%s-%s' % (k, len(list(vs))), set(vs)) 727 728 result = ( 729 p 730 | beam.Create([1, 1, 2, 3, 4, 5, 10, 11]) 731 | beam.FlatMap(lambda t: [('A', t), ('B', t + 5)]) 732 | beam.MapTuple(construct_timestamped) 733 | beam.WindowInto( 734 FixedWindows(10), 735 trigger=_Never(), 736 accumulation_mode=AccumulationMode.DISCARDING) 737 | beam.GroupByKey() 738 | beam.MapTuple(format_result)) 739 assert_that( 740 result, 741 equal_to( 742 list({ 743 'A-2': {10, 11}, 744 'A-6': {1, 2, 3, 4, 5}, 745 'B-5': {6, 7, 8, 9}, 746 'B-3': {10, 15, 16}, 747 }.items()))) 748 749 def test_multiple_accumulating_firings(self): 750 # PCollection will contain elements from 1 to 10. 751 elements = [i for i in range(1, 11)] 752 753 ts = TestStream().advance_watermark_to(0) 754 for i in elements: 755 ts.add_elements([('key', str(i))]) 756 if i % 5 == 0: 757 ts.advance_watermark_to(i) 758 ts.advance_processing_time(5) 759 ts.advance_watermark_to_infinity() 760 761 options = PipelineOptions() 762 options.view_as(StandardOptions).streaming = True 763 with TestPipeline(options=options) as p: 764 records = ( 765 p 766 | ts 767 | beam.WindowInto( 768 FixedWindows(10), 769 accumulation_mode=trigger.AccumulationMode.ACCUMULATING, 770 trigger=AfterWatermark( 771 early=AfterAll(AfterCount(1), AfterProcessingTime(5)))) 772 | beam.GroupByKey() 773 | beam.FlatMap(lambda x: x[1])) 774 775 # The trigger should fire twice. Once after 5 seconds, and once after 10. 776 # The firings should accumulate the output. 777 first_firing = [str(i) for i in elements if i <= 5] 778 second_firing = [str(i) for i in elements] 779 assert_that(records, equal_to(first_firing + second_firing)) 780 781 def test_on_pane_watermark_hold_no_pipeline_stall(self): 782 """A regression test added for 783 ttps://issues.apache.org/jira/browse/BEAM-10054.""" 784 START_TIMESTAMP = 1534842000 785 786 test_stream = TestStream() 787 test_stream.add_elements(['a']) 788 test_stream.advance_processing_time(START_TIMESTAMP + 1) 789 test_stream.advance_watermark_to(START_TIMESTAMP + 1) 790 test_stream.add_elements(['b']) 791 test_stream.advance_processing_time(START_TIMESTAMP + 2) 792 test_stream.advance_watermark_to(START_TIMESTAMP + 2) 793 794 with TestPipeline(options=PipelineOptions( 795 ['--streaming', '--allow_unsafe_triggers'])) as p: 796 # pylint: disable=expression-not-assigned 797 ( 798 p 799 | 'TestStream' >> test_stream 800 | 'timestamp' >> 801 beam.Map(lambda x: beam.window.TimestampedValue(x, START_TIMESTAMP)) 802 | 'kv' >> beam.Map(lambda x: (x, x)) 803 | 'window_1m' >> beam.WindowInto( 804 beam.window.FixedWindows(60), 805 trigger=trigger.AfterAny( 806 trigger.AfterProcessingTime(3600), trigger.AfterWatermark()), 807 accumulation_mode=trigger.AccumulationMode.DISCARDING) 808 | 'group_by_key' >> beam.GroupByKey() 809 | 'filter' >> beam.Map(lambda x: x)) 810 811 812 class TranscriptTest(unittest.TestCase): 813 814 # We must prepend an underscore to this name so that the open-source unittest 815 # runner does not execute this method directly as a test. 816 @classmethod 817 def _create_test(cls, spec): 818 counter = 0 819 name = spec.get('name', 'unnamed') 820 unique_name = 'test_' + name 821 while hasattr(cls, unique_name): 822 counter += 1 823 unique_name = 'test_%s_%d' % (name, counter) 824 test_method = lambda self: self._run_log_test(spec) 825 test_method.__name__ = unique_name 826 test_method.__test__ = True 827 setattr(cls, unique_name, test_method) 828 829 # We must prepend an underscore to this name so that the open-source unittest 830 # runner does not execute this method directly as a test. 831 @classmethod 832 def _create_tests(cls, transcript_filename): 833 for spec in yaml.load_all(open(transcript_filename), 834 Loader=yaml.SafeLoader): 835 cls._create_test(spec) 836 837 def _run_log_test(self, spec): 838 if 'error' in spec: 839 self.assertRaisesRegex(Exception, spec['error'], self._run_log, spec) 840 else: 841 self._run_log(spec) 842 843 def _run_log(self, spec): 844 def parse_int_list(s): 845 """Parses strings like '[1, 2, 3]'.""" 846 s = s.strip() 847 assert s[0] == '[' and s[-1] == ']', s 848 if not s[1:-1].strip(): 849 return [] 850 return [int(x) for x in s[1:-1].split(',')] 851 852 def split_args(s): 853 """Splits 'a, b, [c, d]' into ['a', 'b', '[c, d]'].""" 854 args = [] 855 start = 0 856 depth = 0 857 for ix in range(len(s)): 858 c = s[ix] 859 if c in '({[': 860 depth += 1 861 elif c in ')}]': 862 depth -= 1 863 elif c == ',' and depth == 0: 864 args.append(s[start:ix].strip()) 865 start = ix + 1 866 assert depth == 0, s 867 args.append(s[start:].strip()) 868 return args 869 870 def parse(s, names): 871 """Parse (recursive) 'Foo(arg, kw=arg)' for Foo in the names dict.""" 872 s = s.strip() 873 if s in names: 874 return names[s] 875 elif s[0] == '[': 876 return parse_int_list(s) 877 elif '(' in s: 878 assert s[-1] == ')', s 879 callee = parse(s[:s.index('(')], names) 880 posargs = [] 881 kwargs = {} 882 for arg in split_args(s[s.index('(') + 1:-1]): 883 if '=' in arg: 884 kw, value = arg.split('=', 1) 885 kwargs[kw] = parse(value, names) 886 else: 887 posargs.append(parse(arg, names)) 888 return callee(*posargs, **kwargs) 889 else: 890 try: 891 return int(s) 892 except ValueError: 893 raise ValueError('Unknown function: %s' % s) 894 895 def parse_fn(s, names): 896 """Like parse(), but implicitly calls no-arg constructors.""" 897 fn = parse(s, names) 898 if isinstance(fn, type): 899 return fn() 900 return fn 901 902 # pylint: disable=wrong-import-order, wrong-import-position 903 from apache_beam.transforms import window as window_module 904 # pylint: enable=wrong-import-order, wrong-import-position 905 window_fn_names = dict(window_module.__dict__) 906 # yapf: disable 907 window_fn_names.update({ 908 'CustomTimestampingFixedWindowsWindowFn': 909 CustomTimestampingFixedWindowsWindowFn 910 }) 911 # yapf: enable 912 trigger_names = {'Default': DefaultTrigger} 913 trigger_names.update(trigger.__dict__) 914 915 window_fn = parse_fn( 916 spec.get('window_fn', 'GlobalWindows'), window_fn_names) 917 trigger_fn = parse_fn(spec.get('trigger_fn', 'Default'), trigger_names) 918 accumulation_mode = getattr( 919 AccumulationMode, spec.get('accumulation_mode', 'ACCUMULATING').upper()) 920 timestamp_combiner = getattr( 921 TimestampCombiner, 922 spec.get('timestamp_combiner', 'OUTPUT_AT_EOW').upper()) 923 allowed_lateness = spec.get('allowed_lateness', 0.000) 924 925 def only_element(xs): 926 x, = list(xs) 927 return x 928 929 transcript = [only_element(line.items()) for line in spec['transcript']] 930 931 self._execute( 932 window_fn, 933 trigger_fn, 934 accumulation_mode, 935 timestamp_combiner, 936 allowed_lateness, 937 transcript, 938 spec) 939 940 941 def _windowed_value_info(windowed_value): 942 # Currently some runners operate at the millisecond level, and some at the 943 # microsecond level. Trigger transcript timestamps are expressed as 944 # integral units of the finest granularity, whatever that may be. 945 # In these tests we interpret them as integral seconds and then truncate 946 # the results to integral seconds to allow for portability across 947 # different sub-second resolutions. 948 window, = windowed_value.windows 949 return { 950 'window': [int(window.start), int(window.max_timestamp())], 951 'values': sorted(windowed_value.value), 952 'timestamp': int(windowed_value.timestamp), 953 'index': windowed_value.pane_info.index, 954 'nonspeculative_index': windowed_value.pane_info.nonspeculative_index, 955 'early': windowed_value.pane_info.timing == PaneInfoTiming.EARLY, 956 'late': windowed_value.pane_info.timing == PaneInfoTiming.LATE, 957 'final': windowed_value.pane_info.is_last, 958 } 959 960 961 def _windowed_value_info_map_fn( 962 k, 963 vs, 964 window=beam.DoFn.WindowParam, 965 t=beam.DoFn.TimestampParam, 966 p=beam.DoFn.PaneInfoParam): 967 return ( 968 k, 969 _windowed_value_info( 970 WindowedValue(vs, windows=[window], timestamp=t, pane_info=p))) 971 972 973 def _windowed_value_info_check(actual, expected, key=None): 974 975 key_string = ' for %s' % key if key else '' 976 977 def format(panes): 978 return '\n[%s]\n' % '\n '.join( 979 str(pane) 980 for pane in sorted(panes, key=lambda pane: pane.get('timestamp', None))) 981 982 if len(actual) > len(expected): 983 raise AssertionError( 984 'Unexpected output%s: expected %s but got %s' % 985 (key_string, format(expected), format(actual))) 986 elif len(expected) > len(actual): 987 raise AssertionError( 988 'Unmatched output%s: expected %s but got %s' % 989 (key_string, format(expected), format(actual))) 990 else: 991 992 def diff(actual, expected): 993 for key in sorted(expected.keys(), reverse=True): 994 if key in actual: 995 if actual[key] != expected[key]: 996 return key 997 998 for output in actual: 999 diffs = [diff(output, pane) for pane in expected] 1000 if all(diffs): 1001 raise AssertionError( 1002 'Unmatched output%s: %s not found in %s (diffs in %s)' % 1003 (key_string, output, format(expected), diffs)) 1004 1005 1006 class _ConcatCombineFn(beam.CombineFn): 1007 create_accumulator = lambda self: [] # type: ignore[var-annotated] 1008 add_input = lambda self, acc, element: acc.append(element) or acc 1009 merge_accumulators = lambda self, accs: sum(accs, []) # type: ignore[var-annotated] 1010 extract_output = lambda self, acc: acc 1011 1012 1013 class TriggerDriverTranscriptTest(TranscriptTest): 1014 def _execute( 1015 self, 1016 window_fn, 1017 trigger_fn, 1018 accumulation_mode, 1019 timestamp_combiner, 1020 allowed_lateness, 1021 transcript, 1022 unused_spec): 1023 1024 driver = GeneralTriggerDriver( 1025 Windowing( 1026 window_fn, 1027 trigger_fn, 1028 accumulation_mode, 1029 timestamp_combiner, 1030 allowed_lateness), 1031 TestClock()) 1032 state = InMemoryUnmergedState() 1033 output = [] 1034 watermark = MIN_TIMESTAMP 1035 1036 def fire_timers(): 1037 to_fire = state.get_and_clear_timers(watermark) 1038 while to_fire: 1039 for timer_window, (name, time_domain, t_timestamp, _) in to_fire: 1040 for wvalue in driver.process_timer(timer_window, 1041 name, 1042 time_domain, 1043 t_timestamp, 1044 state): 1045 output.append(_windowed_value_info(wvalue)) 1046 to_fire = state.get_and_clear_timers(watermark) 1047 1048 for action, params in transcript: 1049 1050 if action != 'expect': 1051 # Fail if we have output that was not expected in the transcript. 1052 self.assertEqual([], 1053 output, 1054 msg='Unexpected output: %s before %s: %s' % 1055 (output, action, params)) 1056 1057 if action == 'input': 1058 bundle = [ 1059 WindowedValue(t, t, window_fn.assign(WindowFn.AssignContext(t, t))) 1060 for t in params 1061 ] 1062 output = [ 1063 _windowed_value_info(wv) for wv in driver.process_elements( 1064 state, bundle, watermark, watermark) 1065 ] 1066 fire_timers() 1067 1068 elif action == 'watermark': 1069 watermark = params 1070 fire_timers() 1071 1072 elif action == 'expect': 1073 for expected_output in params: 1074 for candidate in output: 1075 if all(candidate[k] == expected_output[k] for k in candidate 1076 if k in expected_output): 1077 output.remove(candidate) 1078 break 1079 else: 1080 self.fail('Unmatched output %s in %s' % (expected_output, output)) 1081 1082 elif action == 'state': 1083 # TODO(robertwb): Implement once we support allowed lateness. 1084 pass 1085 1086 else: 1087 self.fail('Unknown action: ' + action) 1088 1089 # Fail if we have output that was not expected in the transcript. 1090 self.assertEqual([], output, msg='Unexpected output: %s' % output) 1091 1092 1093 class BaseTestStreamTranscriptTest(TranscriptTest): 1094 """A suite of TestStream-based tests based on trigger transcript entries. 1095 """ 1096 def _execute( 1097 self, 1098 window_fn, 1099 trigger_fn, 1100 accumulation_mode, 1101 timestamp_combiner, 1102 allowed_lateness, 1103 transcript, 1104 spec): 1105 1106 runner_name = TestPipeline().runner.__class__.__name__ 1107 if runner_name in spec.get('broken_on', ()): 1108 self.skipTest('Known to be broken on %s' % runner_name) 1109 1110 is_order_agnostic = ( 1111 isinstance(trigger_fn, DefaultTrigger) and 1112 accumulation_mode == AccumulationMode.ACCUMULATING) 1113 1114 if is_order_agnostic: 1115 reshuffle_seed = random.randrange(1 << 20) 1116 keys = [ 1117 u'original', 1118 u'reversed', 1119 u'reshuffled(%s)' % reshuffle_seed, 1120 u'one-element-bundles', 1121 u'one-element-bundles-reversed', 1122 u'two-element-bundles' 1123 ] 1124 else: 1125 keys = [u'key1', u'key2'] 1126 1127 # Elements are encoded as a json strings to allow other languages to 1128 # decode elements while executing the test stream. 1129 # TODO(https://github.com/apache/beam/issues/19934): Eliminate these 1130 # gymnastics. 1131 test_stream = TestStream(coder=coders.StrUtf8Coder()).with_output_types(str) 1132 for action, params in transcript: 1133 if action == 'expect': 1134 test_stream.add_elements([json.dumps(('expect', params))]) 1135 else: 1136 test_stream.add_elements([json.dumps(('expect', []))]) 1137 if action == 'input': 1138 1139 def keyed(key, values): 1140 return [json.dumps(('input', (key, v))) for v in values] 1141 1142 if is_order_agnostic: 1143 # Must match keys above. 1144 test_stream.add_elements(keyed('original', params)) 1145 test_stream.add_elements(keyed('reversed', reversed(params))) 1146 r = random.Random(reshuffle_seed) 1147 reshuffled = list(params) 1148 r.shuffle(reshuffled) 1149 test_stream.add_elements( 1150 keyed('reshuffled(%s)' % reshuffle_seed, reshuffled)) 1151 for v in params: 1152 test_stream.add_elements(keyed('one-element-bundles', [v])) 1153 for v in reversed(params): 1154 test_stream.add_elements( 1155 keyed('one-element-bundles-reversed', [v])) 1156 for ix in range(0, len(params), 2): 1157 test_stream.add_elements( 1158 keyed('two-element-bundles', params[ix:ix + 2])) 1159 else: 1160 for key in keys: 1161 test_stream.add_elements(keyed(key, params)) 1162 elif action == 'watermark': 1163 test_stream.advance_watermark_to(params) 1164 elif action == 'clock': 1165 test_stream.advance_processing_time(params) 1166 elif action == 'state': 1167 pass # Requires inspection of implementation details. 1168 else: 1169 raise ValueError('Unexpected action: %s' % action) 1170 test_stream.add_elements([json.dumps(('expect', []))]) 1171 test_stream.advance_watermark_to_infinity() 1172 1173 read_test_stream = test_stream | beam.Map(json.loads) 1174 1175 class Check(beam.DoFn): 1176 """A StatefulDoFn that verifies outputs are produced as expected. 1177 1178 This DoFn takes in two kinds of inputs, actual outputs and 1179 expected outputs. When an actual output is received, it is buffered 1180 into state, and when an expected output is received, this buffered 1181 state is retrieved and compared against the expected value(s) to ensure 1182 they match. 1183 1184 The key is ignored, but all items must be on the same key to share state. 1185 """ 1186 def __init__(self, allow_out_of_order=True): 1187 # Some runners don't support cross-stage TestStream semantics. 1188 self.allow_out_of_order = allow_out_of_order 1189 1190 def process( 1191 self, 1192 element, 1193 seen=beam.DoFn.StateParam( 1194 beam.transforms.userstate.BagStateSpec( 1195 'seen', beam.coders.FastPrimitivesCoder())), 1196 expected=beam.DoFn.StateParam( 1197 beam.transforms.userstate.BagStateSpec( 1198 'expected', beam.coders.FastPrimitivesCoder()))): 1199 key, (action, data) = element 1200 1201 if self.allow_out_of_order: 1202 if action == 'expect' and not list(seen.read()): 1203 if data: 1204 expected.add(data) 1205 return 1206 elif action == 'actual' and list(expected.read()): 1207 seen.add(data) 1208 all_data = list(seen.read()) 1209 all_expected = list(expected.read()) 1210 if len(all_data) == len(all_expected[0]): 1211 expected.clear() 1212 for expect in all_expected[1:]: 1213 expected.add(expect) 1214 action, data = 'expect', all_expected[0] 1215 else: 1216 return 1217 1218 if action == 'actual': 1219 seen.add(data) 1220 1221 elif action == 'expect': 1222 actual = list(seen.read()) 1223 seen.clear() 1224 _windowed_value_info_check(actual, data, key) 1225 1226 else: 1227 raise ValueError('Unexpected action: %s' % action) 1228 1229 @ptransform.ptransform_fn 1230 def CheckAggregation(inputs_and_expected, aggregation): 1231 # Split the test stream into a branch of to-be-processed elements, and 1232 # a branch of expected results. 1233 inputs, expected = ( 1234 inputs_and_expected 1235 | beam.MapTuple( 1236 lambda tag, value: beam.pvalue.TaggedOutput(tag, value), 1237 ).with_outputs('input', 'expect')) 1238 1239 # Process the inputs with the given windowing to produce actual outputs. 1240 outputs = ( 1241 inputs 1242 | beam.MapTuple( 1243 lambda key, value: TimestampedValue((key, value), value)) 1244 | beam.WindowInto( 1245 window_fn, 1246 trigger=trigger_fn, 1247 accumulation_mode=accumulation_mode, 1248 timestamp_combiner=timestamp_combiner, 1249 allowed_lateness=allowed_lateness) 1250 | aggregation 1251 | beam.MapTuple(_windowed_value_info_map_fn) 1252 # Place outputs back into the global window to allow flattening 1253 # and share a single state in Check. 1254 | 'Global' >> beam.WindowInto(beam.transforms.window.GlobalWindows())) 1255 # Feed both the expected and actual outputs to Check() for comparison. 1256 tagged_expected = ( 1257 expected | beam.FlatMap( 1258 lambda value: [(key, ('expect', value)) for key in keys])) 1259 tagged_outputs = ( 1260 outputs | beam.MapTuple(lambda key, value: (key, ('actual', value)))) 1261 # pylint: disable=expression-not-assigned 1262 ([tagged_expected, tagged_outputs] 1263 | beam.Flatten() 1264 | beam.ParDo(Check(self.allow_out_of_order))) 1265 1266 with TestPipeline() as p: 1267 # TODO(https://github.com/apache/beam/issues/19933): Pass this during 1268 # pipeline construction. 1269 p._options.view_as(StandardOptions).streaming = True 1270 p._options.view_as(TypeOptions).allow_unsafe_triggers = True 1271 1272 # We can have at most one test stream per pipeline, so we share it. 1273 inputs_and_expected = p | read_test_stream 1274 _ = inputs_and_expected | CheckAggregation(beam.GroupByKey()) 1275 _ = inputs_and_expected | CheckAggregation( 1276 beam.CombinePerKey(_ConcatCombineFn())) 1277 1278 1279 class TestStreamTranscriptTest(BaseTestStreamTranscriptTest): 1280 allow_out_of_order = False 1281 1282 1283 class WeakTestStreamTranscriptTest(BaseTestStreamTranscriptTest): 1284 allow_out_of_order = True 1285 1286 1287 class BatchTranscriptTest(TranscriptTest): 1288 def _execute( 1289 self, 1290 window_fn, 1291 trigger_fn, 1292 accumulation_mode, 1293 timestamp_combiner, 1294 allowed_lateness, 1295 transcript, 1296 spec): 1297 if timestamp_combiner == TimestampCombiner.OUTPUT_AT_EARLIEST_TRANSFORMED: 1298 self.skipTest( 1299 'Non-fnapi timestamp combiner: %s' % spec.get('timestamp_combiner')) 1300 1301 if accumulation_mode != AccumulationMode.ACCUMULATING: 1302 self.skipTest('Batch mode only makes sense for accumulating.') 1303 1304 watermark = MIN_TIMESTAMP 1305 for action, params in transcript: 1306 if action == 'watermark': 1307 watermark = params 1308 elif action == 'input': 1309 if any(t <= watermark for t in params): 1310 self.skipTest('Batch mode never has late data.') 1311 1312 inputs = sum([vs for action, vs in transcript if action == 'input'], []) 1313 final_panes_by_window = {} 1314 for action, params in transcript: 1315 if action == 'expect': 1316 for expected in params: 1317 trimmed = {} 1318 for field in ('window', 'values', 'timestamp'): 1319 if field in expected: 1320 trimmed[field] = expected[field] 1321 final_panes_by_window[tuple(expected['window'])] = trimmed 1322 final_panes = list(final_panes_by_window.values()) 1323 1324 if window_fn.is_merging(): 1325 merged_away = set() 1326 1327 class MergeContext(WindowFn.MergeContext): 1328 def merge(_, to_be_merged, merge_result): 1329 for window in to_be_merged: 1330 if window != merge_result: 1331 merged_away.add(window) 1332 1333 all_windows = [IntervalWindow(*pane['window']) for pane in final_panes] 1334 window_fn.merge(MergeContext(all_windows)) 1335 final_panes = [ 1336 pane for pane in final_panes 1337 if IntervalWindow(*pane['window']) not in merged_away 1338 ] 1339 1340 with TestPipeline() as p: 1341 input_pc = ( 1342 p 1343 | beam.Create(inputs) 1344 | beam.Map(lambda t: TimestampedValue(('key', t), t)) 1345 | beam.WindowInto( 1346 window_fn, 1347 trigger=trigger_fn, 1348 accumulation_mode=accumulation_mode, 1349 timestamp_combiner=timestamp_combiner, 1350 allowed_lateness=allowed_lateness)) 1351 1352 grouped = input_pc | 'Grouped' >> ( 1353 beam.GroupByKey() 1354 | beam.MapTuple(_windowed_value_info_map_fn) 1355 | beam.MapTuple(lambda _, value: value)) 1356 1357 combined = input_pc | 'Combined' >> ( 1358 beam.CombinePerKey(_ConcatCombineFn()) 1359 | beam.MapTuple(_windowed_value_info_map_fn) 1360 | beam.MapTuple(lambda _, value: value)) 1361 1362 assert_that( 1363 grouped, 1364 lambda actual: _windowed_value_info_check(actual, final_panes), 1365 label='CheckGrouped') 1366 1367 assert_that( 1368 combined, 1369 lambda actual: _windowed_value_info_check(actual, final_panes), 1370 label='CheckCombined') 1371 1372 1373 TRANSCRIPT_TEST_FILE = os.path.join( 1374 os.path.dirname(__file__), 1375 '..', 1376 'testing', 1377 'data', 1378 'trigger_transcripts.yaml') 1379 if os.path.exists(TRANSCRIPT_TEST_FILE): 1380 TriggerDriverTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE) 1381 TestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE) 1382 WeakTestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE) 1383 BatchTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE) 1384 1385 if __name__ == '__main__': 1386 unittest.main()