# github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/util_test.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the transforms.util classes."""

# pytype: skip-file

import logging
import math
import random
import re
import time
import unittest
import warnings
from datetime import datetime

import pytest
import pytz

import apache_beam as beam
from apache_beam import GroupByKey
from apache_beam import Map
from apache_beam import WindowInto
from apache_beam.coders import coders
from apache_beam.metrics import MetricsFilter
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import AsList
from apache_beam.pvalue import AsSingleton
from apache_beam.runners import pipeline_context
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import SortLists
from apache_beam.testing.util import TestWindowedValue
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import contains_in_any_order
from apache_beam.testing.util import equal_to
from apache_beam.transforms import trigger
from apache_beam.transforms import util
from apache_beam.transforms import window
from apache_beam.transforms.core import FlatMapTuple
from apache_beam.transforms.trigger import AfterCount
from apache_beam.transforms.trigger import Repeatedly
from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import IntervalWindow
from apache_beam.transforms.window import Sessions
from apache_beam.transforms.window import SlidingWindows
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import typehints
from apache_beam.typehints.sharded_key_type import ShardedKeyType
from apache_beam.utils import proto_utils
from apache_beam.utils import timestamp
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import MIN_TIMESTAMP
from apache_beam.utils.windowed_value import WindowedValue

warnings.filterwarnings(
    'ignore', category=FutureWarning, module='apache_beam.transforms.util_test')


class CoGroupByKeyTest(unittest.TestCase):
  def test_co_group_by_key_on_tuple(self):
    with TestPipeline() as pipeline:
      pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
                                                     ('b', 3), ('c', 4)])
      pcoll_2 = pipeline | 'Start 2' >> beam.Create([('a', 5), ('a', 6),
                                                     ('c', 7), ('c', 8)])
      result = (pcoll_1, pcoll_2) | beam.CoGroupByKey() | SortLists
      assert_that(
          result,
          equal_to([('a', ([1, 2], [5, 6])), ('b', ([3], [])),
                    ('c', ([4], [7, 8]))]))

  def test_co_group_by_key_on_iterable(self):
    with TestPipeline() as pipeline:
      pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
                                                     ('b', 3), ('c', 4)])
      pcoll_2 = pipeline | 'Start 2' >> beam.Create([('a', 5), ('a', 6),
                                                     ('c', 7), ('c', 8)])
      result = iter([pcoll_1, pcoll_2]) | beam.CoGroupByKey() | SortLists
      assert_that(
          result,
          equal_to([('a', ([1, 2], [5, 6])), ('b', ([3], [])),
                    ('c', ([4], [7, 8]))]))

  def test_co_group_by_key_on_list(self):
    with TestPipeline() as pipeline:
      pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
                                                     ('b', 3), ('c', 4)])
      pcoll_2 = pipeline | 'Start 2' >> beam.Create([('a', 5), ('a', 6),
                                                     ('c', 7), ('c', 8)])
      result = [pcoll_1, pcoll_2] | beam.CoGroupByKey() | SortLists
      assert_that(
          result,
          equal_to([('a', ([1, 2], [5, 6])), ('b', ([3], [])),
                    ('c', ([4], [7, 8]))]))

  def test_co_group_by_key_on_dict(self):
    with TestPipeline() as pipeline:
      pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
                                                     ('b', 3), ('c', 4)])
      pcoll_2 = pipeline | 'Start 2' >> beam.Create([('a', 5), ('a', 6),
                                                     ('c', 7), ('c', 8)])
      result = {'X': pcoll_1, 'Y': pcoll_2} | beam.CoGroupByKey() | SortLists
      assert_that(
          result,
          equal_to([('a', {
              'X': [1, 2], 'Y': [5, 6]
          }), ('b', {
              'X': [3], 'Y': []
          }), ('c', {
              'X': [4], 'Y': [7, 8]
          })]))

  def test_co_group_by_key_on_dict_with_tuple_keys(self):
    with TestPipeline() as pipeline:
      key = ('a', ('b', 'c'))
      pcoll_1 = pipeline | 'Start 1' >> beam.Create([(key, 1)])
      pcoll_2 = pipeline | 'Start 2' >> beam.Create([(key, 2)])
      result = {'X': pcoll_1, 'Y': pcoll_2} | beam.CoGroupByKey() | SortLists
      assert_that(result, equal_to([(key, {'X': [1], 'Y': [2]})]))

  def test_co_group_by_key_on_empty(self):
    with TestPipeline() as pipeline:
      assert_that(
          tuple() | 'EmptyTuple' >> beam.CoGroupByKey(pipeline=pipeline),
          equal_to([]),
          label='AssertEmptyTuple')
      assert_that([] | 'EmptyList' >> beam.CoGroupByKey(pipeline=pipeline),
                  equal_to([]),
                  label='AssertEmptyList')
      assert_that(
          iter([]) | 'EmptyIterable' >> beam.CoGroupByKey(pipeline=pipeline),
          equal_to([]),
          label='AssertEmptyIterable')
      assert_that({} | 'EmptyDict' >> beam.CoGroupByKey(pipeline=pipeline),
                  equal_to([]),
                  label='AssertEmptyDict')

  def test_co_group_by_key_on_one(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | beam.Create([('a', 1), ('b', 2)])
      expected = [('a', ([1], )), ('b', ([2], ))]
      assert_that((pcoll, ) | 'OneTuple' >> beam.CoGroupByKey(),
                  equal_to(expected),
                  label='AssertOneTuple')
      assert_that([pcoll] | 'OneList' >> beam.CoGroupByKey(),
                  equal_to(expected),
                  label='AssertOneList')
      assert_that(
          iter([pcoll]) | 'OneIterable' >> beam.CoGroupByKey(),
          equal_to(expected),
          label='AssertOneIterable')
      assert_that({'tag': pcoll}
                  | 'OneDict' >> beam.CoGroupByKey()
                  | beam.MapTuple(lambda k, v: (k, (v['tag'], ))),
                  equal_to(expected),
                  label='AssertOneDict')


class FakeClock(object):
  def __init__(self, now=time.time()):
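    self._now = now

  def __call__(self):
    return self._now

  def sleep(self, duration):
    self._now += duration


# An illustrative sketch, not part of the original suite: FakeClock pairs with
# util._BatchSizeEstimator (exercised at length below) by standing in for
# time.time, so a test can simulate how long each batch "took". The 0.5s
# per-element cost here is an assumption chosen for the example.
def _fake_clock_estimator_sketch():
  clock = FakeClock()
  estimator = util._BatchSizeEstimator(clock=clock)
  for _ in range(3):
    size = estimator.next_batch_size()
    with estimator.record_time(size):
      clock.sleep(0.5 * size)  # pretend processing took 0.5s per element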


class BatchElementsTest(unittest.TestCase):
  def test_constant_batch(self):
    # Assumes a single bundle...
    p = TestPipeline()
    output = (
        p
        | beam.Create(range(35))
        | util.BatchElements(min_batch_size=10, max_batch_size=10)
        | beam.Map(len))
    assert_that(output, equal_to([10, 10, 10, 5]))
    res = p.run()
    res.wait_until_finish()
    metrics = res.metrics()
    results = metrics.query(MetricsFilter().with_name("batch_size"))
    self.assertEqual(len(results["distributions"]), 1)

  def test_constant_batch_no_metrics(self):
    p = TestPipeline()
    output = (
        p
        | beam.Create(range(35))
        | util.BatchElements(
            min_batch_size=10, max_batch_size=10, record_metrics=False)
        | beam.Map(len))
    assert_that(output, equal_to([10, 10, 10, 5]))
    res = p.run()
    res.wait_until_finish()
    metrics = res.metrics()
    results = metrics.query(MetricsFilter().with_name("batch_size"))
    self.assertEqual(len(results["distributions"]), 0)

  def test_grows_to_max_batch(self):
    # Assumes a single bundle...
    with TestPipeline() as p:
      res = (
          p
          | beam.Create(range(164))
          | util.BatchElements(
              min_batch_size=1, max_batch_size=50, clock=FakeClock())
          | beam.Map(len))
      assert_that(res, equal_to([1, 1, 2, 4, 8, 16, 32, 50, 50]))

  def test_windowed_batches(self):
    # Assumes a single bundle, in order...
    with TestPipeline() as p:
      res = (
          p
          | beam.Create(range(47), reshuffle=False)
          | beam.Map(lambda t: window.TimestampedValue(t, t))
          | beam.WindowInto(window.FixedWindows(30))
          | util.BatchElements(
              min_batch_size=5, max_batch_size=10, clock=FakeClock())
          | beam.Map(len))
      assert_that(
          res,
          equal_to([
              5,
              5,
              10,
              10,  # elements in [0, 30)
              10,
              7,  # elements in [30, 47)
          ]))

  def test_global_batch_timestamps(self):
    # Assumes a single bundle
    with TestPipeline() as p:
      res = (
          p
          | beam.Create(range(3), reshuffle=False)
          | util.BatchElements(min_batch_size=2, max_batch_size=2)
          | beam.Map(
              lambda batch,
              timestamp=beam.DoFn.TimestampParam: (len(batch), timestamp)))
      assert_that(
          res,
          equal_to([
              (2, GlobalWindow().max_timestamp()),
              (1, GlobalWindow().max_timestamp()),
          ]))

  def test_sized_batches(self):
    with TestPipeline() as p:
      res = (
          p
          | beam.Create([
              'a', 'a', 'aaaaaaaaaa',  # First batch.
              'aaaaaa', 'aaaaa',  # Second batch.
              'a', 'aaaaaaa', 'a', 'a'  # Third batch.
          ], reshuffle=False)
          | util.BatchElements(
              min_batch_size=10, max_batch_size=10, element_size_fn=len)
          | beam.Map(lambda batch: ''.join(batch))
          | beam.Map(len))
      assert_that(res, equal_to([12, 11, 10]))

  def test_target_duration(self):
    clock = FakeClock()
    batch_estimator = util._BatchSizeEstimator(
        target_batch_overhead=None, target_batch_duration_secs=10, clock=clock)
    batch_duration = lambda batch_size: 1 + .7 * batch_size
    # 14 * .7 = 9.8 is as close to the target of 10 as we can get.
    expected_sizes = [1, 2, 4, 8, 14, 14, 14]
    actual_sizes = []
    for _ in range(len(expected_sizes)):
      actual_sizes.append(batch_estimator.next_batch_size())
      with batch_estimator.record_time(actual_sizes[-1]):
        clock.sleep(batch_duration(actual_sizes[-1]))
    self.assertEqual(expected_sizes, actual_sizes)

  def test_target_duration_including_fixed_cost(self):
    clock = FakeClock()
    batch_estimator = util._BatchSizeEstimator(
        target_batch_overhead=None,
        target_batch_duration_secs_including_fixed_cost=10,
        clock=clock)
    batch_duration = lambda batch_size: 1 + .7 * batch_size
    # 1 + 12 * .7 = 9.4 is as close to the target of 10 as we can get
    # without exceeding it.
    expected_sizes = [1, 2, 4, 8, 12, 12, 12]
    actual_sizes = []
    for _ in range(len(expected_sizes)):
      actual_sizes.append(batch_estimator.next_batch_size())
      with batch_estimator.record_time(actual_sizes[-1]):
        clock.sleep(batch_duration(actual_sizes[-1]))
    self.assertEqual(expected_sizes, actual_sizes)

  def test_target_overhead(self):
    clock = FakeClock()
    batch_estimator = util._BatchSizeEstimator(
        target_batch_overhead=.05, target_batch_duration_secs=None, clock=clock)
    batch_duration = lambda batch_size: 1 + .7 * batch_size
    # At 27 items, a batch takes ~20 seconds with 5% (~1 second) overhead.
    expected_sizes = [1, 2, 4, 8, 16, 27, 27, 27]
    actual_sizes = []
    for _ in range(len(expected_sizes)):
      actual_sizes.append(batch_estimator.next_batch_size())
      with batch_estimator.record_time(actual_sizes[-1]):
        clock.sleep(batch_duration(actual_sizes[-1]))
    self.assertEqual(expected_sizes, actual_sizes)
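
  # Editorial note (not in the original file), working through the three
  # targets above with batch_duration(b) = 1 + .7 * b. The expected sizes are
  # consistent with the estimator keeping its prediction at or below target:
  #   * target_batch_duration_secs=10: .7 * b <= 10 gives b <= 14.3, so it
  #     settles at 14 (variable cost 9.8s).
  #   * target_batch_duration_secs_including_fixed_cost=10: 1 + .7 * b <= 10
  #     gives b <= 12.9, so it settles at 12 (9.4s total).
  #   * target_batch_overhead=.05: fixed/total = 1 / (1 + .7 * b) <= .05
  #     requires b >= ~27, i.e. ~20s batches with ~1s (~5%) overhead.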

  def test_variance(self):
    clock = FakeClock()
    variance = 0.25
    batch_estimator = util._BatchSizeEstimator(
        target_batch_overhead=.05,
        target_batch_duration_secs=None,
        variance=variance,
        clock=clock)
    batch_duration = lambda batch_size: 1 + .7 * batch_size
    expected_target = 27
    actual_sizes = []
    for _ in range(util._BatchSizeEstimator._MAX_DATA_POINTS - 1):
      actual_sizes.append(batch_estimator.next_batch_size())
      with batch_estimator.record_time(actual_sizes[-1]):
        clock.sleep(batch_duration(actual_sizes[-1]))
    # Check that we're testing a good range of values.
    stable_set = set(actual_sizes[-20:])
    self.assertGreater(len(stable_set), 3)
    self.assertGreater(
        min(stable_set), expected_target - expected_target * variance)
    self.assertLess(
        max(stable_set), expected_target + expected_target * variance)

  def test_ignore_first_n_batch_size(self):
    clock = FakeClock()
    batch_estimator = util._BatchSizeEstimator(
        clock=clock, ignore_first_n_seen_per_batch_size=2)

    expected_sizes = [
        1, 1, 1, 2, 2, 2, 4, 4, 4, 8, 8, 8, 16, 16, 16, 32, 32, 32, 64, 64, 64
    ]
    actual_sizes = []
    for i in range(len(expected_sizes)):
      actual_sizes.append(batch_estimator.next_batch_size())
      with batch_estimator.record_time(actual_sizes[-1]):
        if i % 3 == 2:
          clock.sleep(0.01)
        else:
          clock.sleep(1)

    self.assertEqual(expected_sizes, actual_sizes)

    # Check we only record the third timing.
    expected_data_batch_sizes = [1, 2, 4, 8, 16, 32, 64]
    actual_data_batch_sizes = [x[0] for x in batch_estimator._data]
    self.assertEqual(expected_data_batch_sizes, actual_data_batch_sizes)
    expected_data_timing = [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]
    for i in range(len(expected_data_timing)):
      self.assertAlmostEqual(
          expected_data_timing[i], batch_estimator._data[i][1])

  def test_ignore_next_timing(self):
    clock = FakeClock()
    batch_estimator = util._BatchSizeEstimator(clock=clock)
    batch_estimator.ignore_next_timing()

    expected_sizes = [1, 1, 2, 4, 8, 16]
    actual_sizes = []
    for i in range(len(expected_sizes)):
      actual_sizes.append(batch_estimator.next_batch_size())
      with batch_estimator.record_time(actual_sizes[-1]):
        if i == 0:
          clock.sleep(1)
        else:
          clock.sleep(0.01)

    self.assertEqual(expected_sizes, actual_sizes)

    # Check the first record_time was skipped.
    expected_data_batch_sizes = [1, 2, 4, 8, 16]
    actual_data_batch_sizes = [x[0] for x in batch_estimator._data]
    self.assertEqual(expected_data_batch_sizes, actual_data_batch_sizes)
    expected_data_timing = [0.01, 0.01, 0.01, 0.01, 0.01]
    for i in range(len(expected_data_timing)):
      self.assertAlmostEqual(
          expected_data_timing[i], batch_estimator._data[i][1])

  def _run_regression_test(self, linear_regression_fn, test_outliers):
    xs = [random.random() for _ in range(10)]
    ys = [2 * x + 1 for x in xs]
    a, b = linear_regression_fn(xs, ys)
    self.assertAlmostEqual(a, 1)
    self.assertAlmostEqual(b, 2)

    xs = [1 + random.random() for _ in range(100)]
    ys = [7 * x + 5 + 0.01 * random.random() for x in xs]
    a, b = linear_regression_fn(xs, ys)
    self.assertAlmostEqual(a, 5, delta=0.02)
    self.assertAlmostEqual(b, 7, delta=0.02)

    # Test repeated xs
    xs = [1 + random.random()] * 100
    ys = [7 * x + 5 + 0.01 * random.random() for x in xs]
    a, b = linear_regression_fn(xs, ys)
    self.assertAlmostEqual(a, 0, delta=0.02)
    self.assertAlmostEqual(b, sum(ys) / (len(ys) * xs[0]), delta=0.02)

    if test_outliers:
      xs = [1 + random.random() for _ in range(100)]
      ys = [2 * x + 1 for x in xs]
      a, b = linear_regression_fn(xs, ys)
      self.assertAlmostEqual(a, 1)
      self.assertAlmostEqual(b, 2)

      # An outlier or two doesn't affect the result.
      for _ in range(2):
        xs += [10]
        ys += [30]
        a, b = linear_regression_fn(xs, ys)
        self.assertAlmostEqual(a, 1)
        self.assertAlmostEqual(b, 2)

      # But enough of them, and they're no longer outliers.
      xs += [10] * 10
      ys += [30] * 10
      a, b = linear_regression_fn(xs, ys)
      self.assertLess(a, 0.5)
      self.assertGreater(b, 2.5)

  def test_no_numpy_regression(self):
    self._run_regression_test(
        util._BatchSizeEstimator.linear_regression_no_numpy, False)

  def test_numpy_regression(self):
    try:
      # pylint: disable=wrong-import-order, wrong-import-position
      import numpy as _
    except ImportError:
      self.skipTest('numpy not available')
    self._run_regression_test(
        util._BatchSizeEstimator.linear_regression_numpy, True)
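

# An illustrative sketch, not part of the original suite: calling the
# regression helper tested above directly. On noiseless data it recovers the
# intercept and slope of y = 1 + 2x essentially exactly; the inputs are
# assumptions chosen for the example.
def _linear_regression_sketch():
  xs = [1.0, 2.0, 3.0, 4.0]
  ys = [1 + 2 * x for x in xs]
  a, b = util._BatchSizeEstimator.linear_regression_no_numpy(xs, ys)
  assert abs(a - 1) < 1e-6  # intercept
  assert abs(b - 2) < 1e-6  # slope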


class IdentityWindowTest(unittest.TestCase):
  def test_window_preserved(self):
    expected_timestamp = timestamp.Timestamp(5)
    expected_window = window.IntervalWindow(1.0, 2.0)

    class AddWindowDoFn(beam.DoFn):
      def process(self, element):
        yield WindowedValue(element, expected_timestamp, [expected_window])

    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
      expected_windows = [
          TestWindowedValue(kv, expected_timestamp, [expected_window])
          for kv in data
      ]
      before_identity = (
          pipeline
          | 'start' >> beam.Create(data)
          | 'add_windows' >> beam.ParDo(AddWindowDoFn()))
      assert_that(
          before_identity,
          equal_to(expected_windows),
          label='before_identity',
          reify_windows=True)
      after_identity = (
          before_identity
          | 'window' >> beam.WindowInto(
              beam.transforms.util._IdentityWindowFn(
                  coders.IntervalWindowCoder())))
      assert_that(
          after_identity,
          equal_to(expected_windows),
          label='after_identity',
          reify_windows=True)

  def test_no_window_context_fails(self):
    expected_timestamp = timestamp.Timestamp(5)
    # Assuming the default window function is window.GlobalWindows.
    expected_window = window.GlobalWindow()

    class AddTimestampDoFn(beam.DoFn):
      def process(self, element):
        yield window.TimestampedValue(element, expected_timestamp)

    with self.assertRaisesRegex(ValueError, r'window.*None.*add_timestamps2'):
      with TestPipeline() as pipeline:
        data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
        expected_windows = [
            TestWindowedValue(kv, expected_timestamp, [expected_window])
            for kv in data
        ]
        before_identity = (
            pipeline
            | 'start' >> beam.Create(data)
            | 'add_timestamps' >> beam.ParDo(AddTimestampDoFn()))
        assert_that(
            before_identity,
            equal_to(expected_windows),
            label='before_identity',
            reify_windows=True)
        after_identity = (
            before_identity
            | 'window' >> beam.WindowInto(
                beam.transforms.util._IdentityWindowFn(
                    coders.GlobalWindowCoder()))
            # This DoFn will return TimestampedValues, making
            # WindowFn.AssignContext passed to IdentityWindowFn
            # contain a window of None. IdentityWindowFn should
            # raise an exception.
            | 'add_timestamps2' >> beam.ParDo(AddTimestampDoFn()))
        assert_that(
            after_identity,
            equal_to(expected_windows),
            label='after_identity',
            reify_windows=True)


class ReshuffleTest(unittest.TestCase):
  def test_reshuffle_contents_unchanged(self):
    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
      result = (pipeline | beam.Create(data) | beam.Reshuffle())
      assert_that(result, equal_to(data))

  def test_reshuffle_contents_unchanged_with_buckets(self):
    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
      buckets = 2
      result = (pipeline | beam.Create(data) | beam.Reshuffle(buckets))
      assert_that(result, equal_to(data))

  def test_reshuffle_contents_unchanged_with_wrong_buckets(self):
    wrong_buckets = [0, -1, "wrong", 2.5]
    for wrong_bucket in wrong_buckets:
      with self.assertRaisesRegex(ValueError,
                                  'If `num_buckets` is set, it has to be an '
                                  'integer greater than 0, got %s' %
                                  wrong_bucket):
        beam.Reshuffle(wrong_bucket)

  def test_reshuffle_after_gbk_contents_unchanged(self):
    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
      expected_result = [(1, [1, 2, 3]), (2, [1, 2]), (3, [1])]

      after_gbk = (
          pipeline
          | beam.Create(data)
          | beam.GroupByKey()
          | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
      assert_that(after_gbk, equal_to(expected_result), label='after_gbk')
      after_reshuffle = after_gbk | beam.Reshuffle()
      assert_that(
          after_reshuffle, equal_to(expected_result), label='after_reshuffle')

  def test_reshuffle_timestamps_unchanged(self):
    with TestPipeline() as pipeline:
      timestamp = 5
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
      expected_result = [
          TestWindowedValue(v, timestamp, [GlobalWindow()]) for v in data
      ]
      before_reshuffle = (
          pipeline
          | 'start' >> beam.Create(data)
          | 'add_timestamp' >>
          beam.Map(lambda v: beam.window.TimestampedValue(v, timestamp)))
      assert_that(
          before_reshuffle,
          equal_to(expected_result),
          label='before_reshuffle',
          reify_windows=True)
      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle,
          equal_to(expected_result),
          label='after_reshuffle',
          reify_windows=True)

  def test_reshuffle_windows_unchanged(self):
    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
      expected_data = [
          TestWindowedValue(v, t - .001, [w]) for (v, t, w) in [
              ((1, contains_in_any_order([2, 1])), 4.0,
               IntervalWindow(1.0, 4.0)),
              ((2, contains_in_any_order([2, 1])), 4.0,
               IntervalWindow(1.0, 4.0)),
              ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
              ((1, [4]), 6.0, IntervalWindow(4.0, 6.0)),
          ]
      ]
      before_reshuffle = (
          pipeline
          | 'start' >> beam.Create(data)
          | 'add_timestamp' >>
          beam.Map(lambda v: beam.window.TimestampedValue(v, v[1]))
          | 'window' >> beam.WindowInto(Sessions(gap_size=2))
          | 'group_by_key' >> beam.GroupByKey())
      assert_that(
          before_reshuffle,
          equal_to(expected_data),
          label='before_reshuffle',
          reify_windows=True)
      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle,
          equal_to(expected_data),
          label='after reshuffle',
          reify_windows=True)

  def test_reshuffle_window_fn_preserved(self):
    any_order = contains_in_any_order
    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
      expected_windows = [
          TestWindowedValue(v, t, [w]) for (v, t, w) in [
              ((1, 1), 1.0, IntervalWindow(1.0, 3.0)),
              ((2, 1), 1.0, IntervalWindow(1.0, 3.0)),
              ((3, 1), 1.0, IntervalWindow(1.0, 3.0)),
              ((1, 2), 2.0, IntervalWindow(2.0, 4.0)),
              ((2, 2), 2.0, IntervalWindow(2.0, 4.0)),
              ((1, 4), 4.0, IntervalWindow(4.0, 6.0)),
          ]
      ]
      expected_merged_windows = [
          TestWindowedValue(v, t - .001, [w]) for (v, t, w) in [
              ((1, any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
              ((2, any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
              ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
              ((1, [4]), 6.0, IntervalWindow(4.0, 6.0)),
          ]
      ]
      before_reshuffle = (
          pipeline
          | 'start' >> beam.Create(data)
          | 'add_timestamp' >> beam.Map(lambda v: TimestampedValue(v, v[1]))
          | 'window' >> beam.WindowInto(Sessions(gap_size=2)))
      assert_that(
          before_reshuffle,
          equal_to(expected_windows),
          label='before_reshuffle',
          reify_windows=True)
      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle,
          equal_to(expected_windows),
          label='after_reshuffle',
          reify_windows=True)
      after_group = after_reshuffle | beam.GroupByKey()
      assert_that(
          after_group,
          equal_to(expected_merged_windows),
          label='after_group',
          reify_windows=True)

  def test_reshuffle_global_window(self):
    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
      expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]
      before_reshuffle = (
          pipeline
          | beam.Create(data)
          | beam.WindowInto(GlobalWindows())
          | beam.GroupByKey()
          | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
      assert_that(
          before_reshuffle, equal_to(expected_data), label='before_reshuffle')
      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle, equal_to(expected_data), label='after reshuffle')

  def test_reshuffle_sliding_window(self):
    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
      window_size = 2
      expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])] * window_size
      before_reshuffle = (
          pipeline
          | beam.Create(data)
          | beam.WindowInto(SlidingWindows(size=window_size, period=1))
          | beam.GroupByKey()
          | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
      assert_that(
          before_reshuffle, equal_to(expected_data), label='before_reshuffle')
      after_reshuffle = before_reshuffle | beam.Reshuffle()
      # If Reshuffle applies the sliding window function a second time there
      # should be extra values for each key.
      assert_that(
          after_reshuffle, equal_to(expected_data), label='after reshuffle')

  def test_reshuffle_streaming_global_window(self):
    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    with TestPipeline(options=options) as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
      expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]
      before_reshuffle = (
          pipeline
          | beam.Create(data)
          | beam.WindowInto(GlobalWindows())
          | beam.GroupByKey()
          | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
      assert_that(
          before_reshuffle, equal_to(expected_data), label='before_reshuffle')
      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle, equal_to(expected_data), label='after reshuffle')

  def test_reshuffle_streaming_global_window_with_buckets(self):
    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    with TestPipeline(options=options) as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
      expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]
      buckets = 2
      before_reshuffle = (
          pipeline
          | beam.Create(data)
          | beam.WindowInto(GlobalWindows())
          | beam.GroupByKey()
          | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
      assert_that(
          before_reshuffle, equal_to(expected_data), label='before_reshuffle')
      after_reshuffle = before_reshuffle | beam.Reshuffle(buckets)
      assert_that(
          after_reshuffle, equal_to(expected_data), label='after reshuffle')

  @pytest.mark.it_validatesrunner
  def test_reshuffle_preserves_timestamps(self):
    with TestPipeline() as pipeline:

      # Create a PCollection and assign each element a different timestamp.
      before_reshuffle = (
          pipeline
          | beam.Create([
              {
                  'name': 'foo', 'timestamp': MIN_TIMESTAMP
              },
              {
                  'name': 'foo', 'timestamp': 0
              },
              {
                  'name': 'bar', 'timestamp': 33
              },
              {
                  'name': 'bar', 'timestamp': 0
              },
          ])
          | beam.Map(
              lambda element: beam.window.TimestampedValue(
                  element, element['timestamp'])))

      # Reshuffle the PCollection above and assign the timestamp of an element
      # to that element again.
      after_reshuffle = before_reshuffle | beam.Reshuffle()

      # Given an element, emits a string which contains the timestamp and the
      # name field of the element.
      def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam):
        t = str(timestamp)
        if timestamp == MIN_TIMESTAMP:
          t = 'MIN_TIMESTAMP'
        elif timestamp == MAX_TIMESTAMP:
          t = 'MAX_TIMESTAMP'
        return '{} - {}'.format(t, element['name'])

      # Combine each element in before_reshuffle with its timestamp.
      formatted_before_reshuffle = (
          before_reshuffle
          | "Get before_reshuffle timestamp" >> beam.Map(format_with_timestamp))

      # Combine each element in after_reshuffle with its timestamp.
      formatted_after_reshuffle = (
          after_reshuffle
          | "Get after_reshuffle timestamp" >> beam.Map(format_with_timestamp))

      expected_data = [
          'MIN_TIMESTAMP - foo',
          'Timestamp(0) - foo',
          'Timestamp(33) - bar',
          'Timestamp(0) - bar'
      ]

      # Can't compare formatted_before_reshuffle and formatted_after_reshuffle
      # directly, because they are deferred PCollections, while equal_to only
      # takes a concrete argument.
      assert_that(
          formatted_before_reshuffle,
          equal_to(expected_data),
          label="formatted_before_reshuffle")
      assert_that(
          formatted_after_reshuffle,
          equal_to(expected_data),
          label="formatted_after_reshuffle")
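

# An illustrative sketch, not part of the original suite: on contents alone,
# Reshuffle is an identity transform, as the tests above verify from several
# angles; it only redistributes elements between workers. The inputs here are
# assumptions chosen for the example.
def _reshuffle_is_identity_sketch():
  with TestPipeline() as p:
    out = p | beam.Create([1, 2, 3]) | beam.Reshuffle()
    assert_that(out, equal_to([1, 2, 3]))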


class WithKeysTest(unittest.TestCase):
  def setUp(self):
    self.l = [1, 2, 3]

  def test_constant_k(self):
    with TestPipeline() as p:
      pc = p | beam.Create(self.l)
      with_keys = pc | util.WithKeys('k')
      assert_that(with_keys, equal_to([('k', 1), ('k', 2), ('k', 3)], ))

  def test_callable_k(self):
    with TestPipeline() as p:
      pc = p | beam.Create(self.l)
      with_keys = pc | util.WithKeys(lambda x: x * x)
      assert_that(with_keys, equal_to([(1, 1), (4, 2), (9, 3)]))

  @staticmethod
  def _test_args_kwargs_fn(x, multiply, subtract):
    return x * multiply - subtract

  def test_args_kwargs_k(self):
    with TestPipeline() as p:
      pc = p | beam.Create(self.l)
      with_keys = pc | util.WithKeys(
          WithKeysTest._test_args_kwargs_fn, 2, subtract=1)
      assert_that(with_keys, equal_to([(1, 1), (3, 2), (5, 3)]))

  def test_sideinputs(self):
    with TestPipeline() as p:
      pc = p | beam.Create(self.l)
      si1 = AsList(p | "side input 1" >> beam.Create([1, 2, 3]))
      si2 = AsSingleton(p | "side input 2" >> beam.Create([10]))
      with_keys = pc | util.WithKeys(
          lambda x, the_list, the_singleton: x + sum(the_list) + the_singleton,
          si1,
          the_singleton=si2)
      assert_that(with_keys, equal_to([(17, 1), (18, 2), (19, 3)]))


class GroupIntoBatchesTest(unittest.TestCase):
  NUM_ELEMENTS = 10
  BATCH_SIZE = 5

  @staticmethod
  def _create_test_data():
    scientists = [
        "Einstein",
        "Darwin",
        "Copernicus",
        "Pasteur",
        "Curie",
        "Faraday",
        "Newton",
        "Bohr",
        "Galilei",
        "Maxwell"
    ]

    data = []
    for i in range(GroupIntoBatchesTest.NUM_ELEMENTS):
      index = i % len(scientists)
      data.append(("key", scientists[index]))
    return data

  def test_in_global_window(self):
    with TestPipeline() as pipeline:
      collection = pipeline \
          | beam.Create(GroupIntoBatchesTest._create_test_data()) \
          | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE)
      num_batches = collection | beam.combiners.Count.Globally()
      assert_that(
          num_batches,
          equal_to([
              int(
                  math.ceil(
                      GroupIntoBatchesTest.NUM_ELEMENTS /
                      GroupIntoBatchesTest.BATCH_SIZE))
          ]))

  def test_with_sharded_key_in_global_window(self):
    with TestPipeline() as pipeline:
      collection = (
          pipeline
          | beam.Create(GroupIntoBatchesTest._create_test_data())
          | util.GroupIntoBatches.WithShardedKey(
              GroupIntoBatchesTest.BATCH_SIZE))
      num_batches = collection | beam.combiners.Count.Globally()
      assert_that(
          num_batches,
          equal_to([
              int(
                  math.ceil(
                      GroupIntoBatchesTest.NUM_ELEMENTS /
                      GroupIntoBatchesTest.BATCH_SIZE))
          ]))

  def test_buffering_timer_in_fixed_window_streaming(self):
    window_duration = 6
    max_buffering_duration_secs = 100

    start_time = timestamp.Timestamp(0)
    test_stream = (
        TestStream().add_elements([
            TimestampedValue(value, start_time + i) for i,
            value in enumerate(GroupIntoBatchesTest._create_test_data())
        ]).advance_processing_time(150).advance_watermark_to(
            start_time + window_duration).advance_watermark_to(
                start_time + window_duration +
                1).advance_watermark_to_infinity())

    with TestPipeline(options=StandardOptions(streaming=True)) as pipeline:
      # To trigger the processing time timer, use a fake clock with start time
      # being Timestamp(0).
      fake_clock = FakeClock(now=start_time)

      num_elements_per_batch = (
          pipeline | test_stream
          | "fixed window" >> WindowInto(FixedWindows(window_duration))
          | util.GroupIntoBatches(
              GroupIntoBatchesTest.BATCH_SIZE,
              max_buffering_duration_secs,
              fake_clock)
          | "count elements in batch" >> Map(lambda x: (None, len(x[1])))
          | GroupByKey()
          | "global window" >> WindowInto(GlobalWindows())
          | FlatMapTuple(lambda k, vs: vs))

      # The window duration is 6 and the batch size is 5, so the first batch
      # has 5 elements (flushed because the batch size was reached).
      expected_0 = 5
      # Only one element is left in the first window, so the next batch has
      # 1 element (flushed because the max buffering duration was reached).
      expected_1 = 1
      # The collection has 10 elements; the remaining 4 fall into the second
      # window, so the last batch has 4 elements (flushed because the end of
      # the window was reached).
      expected_2 = 4
      assert_that(
          num_elements_per_batch,
          equal_to([expected_0, expected_1, expected_2]),
          "assert2")

  def test_buffering_timer_in_global_window_streaming(self):
    max_buffering_duration_secs = 42

    start_time = timestamp.Timestamp(0)
    test_stream = TestStream().advance_watermark_to(start_time)
    for i, value in enumerate(GroupIntoBatchesTest._create_test_data()):
      test_stream.add_elements(
          [TimestampedValue(value, start_time + i)]) \
        .advance_processing_time(5)
    test_stream.advance_watermark_to(
        start_time + GroupIntoBatchesTest.NUM_ELEMENTS + 1) \
      .advance_watermark_to_infinity()

    with TestPipeline(options=StandardOptions(streaming=True)) as pipeline:
      # Set a batch size larger than the total number of elements.
      # Since we're in a global window, we would have been waiting
      # for all the elements to arrive without the buffering time limit.
      batch_size = GroupIntoBatchesTest.NUM_ELEMENTS * 2

      # To trigger the processing time timer, use a fake clock with start time
      # being Timestamp(0). The fake clock never advances during pipeline
      # execution, so the timer is always set to the same value; after its
      # first firing, it fires on every subsequent element.
      fake_clock = FakeClock(now=start_time)

      num_elements_per_batch = (
          pipeline | test_stream
          | WindowInto(
              GlobalWindows(),
              trigger=Repeatedly(AfterCount(1)),
              accumulation_mode=trigger.AccumulationMode.DISCARDING)
          | util.GroupIntoBatches(
              batch_size, max_buffering_duration_secs, fake_clock)
          | 'count elements in batch' >> Map(lambda x: (None, len(x[1])))
          | GroupByKey()
          | FlatMapTuple(lambda k, vs: vs))

      # We will flush twice: when the max buffering duration is reached and
      # when the global window ends.
      assert_that(num_elements_per_batch, equal_to([9, 1]))

  def test_output_typehints(self):
    transform = util.GroupIntoBatches.WithShardedKey(
        GroupIntoBatchesTest.BATCH_SIZE)
    unused_input_type = typehints.Tuple[str, str]
    output_type = transform.infer_output_type(unused_input_type)
    self.assertTrue(isinstance(output_type, typehints.TupleConstraint))
    k, v = output_type.tuple_types
    self.assertTrue(isinstance(k, ShardedKeyType))
    self.assertTrue(isinstance(v, typehints.IterableTypeConstraint))

    with TestPipeline() as pipeline:
      collection = (
          pipeline
          | beam.Create([((1, 2), 'a'), ((2, 3), 'b')])
          | util.GroupIntoBatches.WithShardedKey(
              GroupIntoBatchesTest.BATCH_SIZE))
      self.assertTrue(
          collection.element_type,
          typehints.Tuple[
              ShardedKeyType[typehints.Tuple[int, int]],  # type: ignore[misc]
              typehints.Iterable[str]])

  def _test_runner_api_round_trip(self, transform, urn):
    context = pipeline_context.PipelineContext()
    proto = transform.to_runner_api(context)
    self.assertEqual(urn, proto.urn)
    payload = (
        proto_utils.parse_Bytes(
            proto.payload, beam_runner_api_pb2.GroupIntoBatchesPayload))
    self.assertEqual(transform.params.batch_size, payload.batch_size)
    self.assertEqual(
        transform.params.max_buffering_duration_secs * 1000,
        payload.max_buffering_duration_millis)

    transform_from_proto = (
        transform.__class__.from_runner_api_parameter(None, payload, None))
    self.assertIsInstance(transform_from_proto, transform.__class__)
    self.assertEqual(transform.params, transform_from_proto.params)

  def test_runner_api(self):
    batch_size = 10
    max_buffering_duration_secs = [None, 0, 5]

    for duration in max_buffering_duration_secs:
      self._test_runner_api_round_trip(
          util.GroupIntoBatches(batch_size, duration),
          common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn)
    self._test_runner_api_round_trip(
        util.GroupIntoBatches(batch_size),
        common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn)

    for duration in max_buffering_duration_secs:
      self._test_runner_api_round_trip(
          util.GroupIntoBatches.WithShardedKey(batch_size, duration),
          common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn)
    self._test_runner_api_round_trip(
        util.GroupIntoBatches.WithShardedKey(batch_size),
        common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn)
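

# An illustrative batch-mode sketch, not part of the original suite:
# GroupIntoBatches emits (key, [values]) tuples with at most batch_size
# values per batch, so one key with five values and a batch size of 2
# yields batches of sizes 2, 2 and 1. The inputs are assumptions chosen
# for the example.
def _group_into_batches_sketch():
  with TestPipeline() as p:
    batch_sizes = (
        p
        | beam.Create([('k', i) for i in range(5)])
        | util.GroupIntoBatches(2)
        | beam.Map(lambda kv: len(kv[1])))
    assert_that(batch_sizes, equal_to([2, 2, 1]))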


class ToStringTest(unittest.TestCase):
  def test_tostring_elements(self):
    with TestPipeline() as p:
      result = (p | beam.Create([1, 1, 2, 3]) | util.ToString.Element())
      assert_that(result, equal_to(["1", "1", "2", "3"]))

  def test_tostring_iterables(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create([("one", "two", "three"), ("four", "five", "six")])
          | util.ToString.Iterables())
      assert_that(result, equal_to(["one,two,three", "four,five,six"]))

  def test_tostring_iterables_with_delimeter(self):
    with TestPipeline() as p:
      data = [("one", "two", "three"), ("four", "five", "six")]
      result = (p | beam.Create(data) | util.ToString.Iterables("\t"))
      assert_that(result, equal_to(["one\ttwo\tthree", "four\tfive\tsix"]))

  def test_tostring_kvs(self):
    with TestPipeline() as p:
      result = (p | beam.Create([("one", 1), ("two", 2)]) | util.ToString.Kvs())
      assert_that(result, equal_to(["one,1", "two,2"]))

  def test_tostring_kvs_delimeter(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create([("one", 1), ("two", 2)]) | util.ToString.Kvs("\t"))
      assert_that(result, equal_to(["one\t1", "two\t2"]))

  def test_tostring_kvs_empty_delimeter(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create([("one", 1), ("two", 2)]) | util.ToString.Kvs(""))
      assert_that(result, equal_to(["one1", "two2"]))


class LogElementsTest(unittest.TestCase):
  @pytest.fixture(scope="function")
  def _capture_stdout_log(request, capsys):
    with TestPipeline() as p:
      result = (
          p | beam.Create([
              TimestampedValue(
                  "event",
                  datetime(2022, 10, 1, 0, 0, 0, 0,
                           tzinfo=pytz.UTC).timestamp()),
              TimestampedValue(
                  "event",
                  datetime(2022, 10, 2, 0, 0, 0, 0,
                           tzinfo=pytz.UTC).timestamp()),
          ])
          | beam.WindowInto(FixedWindows(60))
          | util.LogElements(
              prefix='prefix_', with_window=True, with_timestamp=True))

    request.captured_stdout = capsys.readouterr().out
    return result

  @pytest.mark.usefixtures("_capture_stdout_log")
  def test_stdout_logs(self):
    assert self.captured_stdout == \
        ("prefix_event, timestamp='2022-10-01T00:00:00Z', "
         "window(start=2022-10-01T00:00:00Z, end=2022-10-01T00:01:00Z)\n"
         "prefix_event, timestamp='2022-10-02T00:00:00Z', "
         "window(start=2022-10-02T00:00:00Z, end=2022-10-02T00:01:00Z)\n"), \
        f'Received from stdout: {self.captured_stdout}'

  def test_ptransform_output(self):
    with TestPipeline() as p:
      result = (
          p
          | beam.Create(['a', 'b', 'c'])
          | util.LogElements(prefix='prefix_'))
      assert_that(result, equal_to(['a', 'b', 'c']))


class ReifyTest(unittest.TestCase):
  def test_timestamp(self):
    l = [
        TimestampedValue('a', 100),
        TimestampedValue('b', 200),
        TimestampedValue('c', 300)
    ]
    expected = [
        TestWindowedValue('a', 100, [GlobalWindow()]),
        TestWindowedValue('b', 200, [GlobalWindow()]),
        TestWindowedValue('c', 300, [GlobalWindow()])
    ]
    with TestPipeline() as p:
      # The Map(lambda x: x) PTransform is added after Create here because,
      # when a PCollection of TimestampedValues is created with the Create
      # PTransform, the timestamps are not assigned to it. Adding a Map forces
      # the PCollection to go through a DoFn, so that it consists of elements
      # with timestamps assigned to them instead of a PCollection of
      # TimestampedValue(element, timestamp).
      pc = p | beam.Create(l) | beam.Map(lambda x: x)
      reified_pc = pc | util.Reify.Timestamp()
      assert_that(reified_pc, equal_to(expected), reify_windows=True)

  def test_window(self):
    l = [
        GlobalWindows.windowed_value('a', 100),
        GlobalWindows.windowed_value('b', 200),
        GlobalWindows.windowed_value('c', 300)
    ]
    expected = [
        TestWindowedValue(('a', 100, GlobalWindow()), 100, [GlobalWindow()]),
        TestWindowedValue(('b', 200, GlobalWindow()), 200, [GlobalWindow()]),
        TestWindowedValue(('c', 300, GlobalWindow()), 300, [GlobalWindow()])
    ]
    with TestPipeline() as p:
      pc = p | beam.Create(l)
      # The Map(lambda x: x) PTransform is added after Create here because,
      # when a PCollection of WindowedValues is created with the Create
      # PTransform, the windows are not assigned to it. Adding a Map forces
      # the PCollection to go through a DoFn, so that it consists of elements
      # with windows assigned to them instead of a PCollection of
      # WindowedValue(element, timestamp, window).
      pc = pc | beam.Map(lambda x: x)
      reified_pc = pc | util.Reify.Window()
      assert_that(reified_pc, equal_to(expected), reify_windows=True)

  def test_timestamp_in_value(self):
    l = [
        TimestampedValue(('a', 1), 100),
        TimestampedValue(('b', 2), 200),
        TimestampedValue(('c', 3), 300)
    ]
    expected = [
        TestWindowedValue(('a', TimestampedValue(1, 100)),
                          100, [GlobalWindow()]),
        TestWindowedValue(('b', TimestampedValue(2, 200)),
                          200, [GlobalWindow()]),
        TestWindowedValue(('c', TimestampedValue(3, 300)),
                          300, [GlobalWindow()])
    ]
    with TestPipeline() as p:
      pc = p | beam.Create(l) | beam.Map(lambda x: x)
      reified_pc = pc | util.Reify.TimestampInValue()
      assert_that(reified_pc, equal_to(expected), reify_windows=True)

  def test_window_in_value(self):
    l = [
        GlobalWindows.windowed_value(('a', 1), 100),
        GlobalWindows.windowed_value(('b', 2), 200),
        GlobalWindows.windowed_value(('c', 3), 300)
    ]
    expected = [
        TestWindowedValue(('a', (1, 100, GlobalWindow())),
                          100, [GlobalWindow()]),
        TestWindowedValue(('b', (2, 200, GlobalWindow())),
                          200, [GlobalWindow()]),
        TestWindowedValue(('c', (3, 300, GlobalWindow())),
                          300, [GlobalWindow()])
    ]
    with TestPipeline() as p:
      # The Map(lambda x: x) hack is used here for the same reason.
      # It also makes the typehint on Reify.WindowInValue work.
      pc = p | beam.Create(l) | beam.Map(lambda x: x)
      reified_pc = pc | util.Reify.WindowInValue()
      assert_that(reified_pc, equal_to(expected), reify_windows=True)


class RegexTest(unittest.TestCase):
  def test_find(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["aj", "xj", "yj", "zj"])
          | util.Regex.find("[xyz]"))
      assert_that(result, equal_to(["x", "y", "z"]))

  def test_find_pattern(self):
    with TestPipeline() as p:
      rc = re.compile("[xyz]")
      result = (p | beam.Create(["aj", "xj", "yj", "zj"]) | util.Regex.find(rc))
      assert_that(result, equal_to(["x", "y", "z"]))

  def test_find_group(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["aj", "xj", "yj", "zj"])
          | util.Regex.find("([xyz])j", group=1))
      assert_that(result, equal_to(["x", "y", "z"]))

  def test_find_empty(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["a", "b", "c", "d"])
          | util.Regex.find("[xyz]"))
      assert_that(result, equal_to([]))

  def test_find_group_name(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["aj", "xj", "yj", "zj"])
          | util.Regex.find("(?P<namedgroup>[xyz])j", group="namedgroup"))
      assert_that(result, equal_to(["x", "y", "z"]))

  def test_find_group_name_pattern(self):
    with TestPipeline() as p:
      rc = re.compile("(?P<namedgroup>[xyz])j")
      result = (
          p | beam.Create(["aj", "xj", "yj", "zj"])
          | util.Regex.find(rc, group="namedgroup"))
      assert_that(result, equal_to(["x", "y", "z"]))

  def test_find_all_groups(self):
    data = ["abb ax abbb", "abc qwerty abcabcd xyz"]
    with TestPipeline() as p:
      pcol = (p | beam.Create(data))

      assert_that(
          pcol | 'with default values' >> util.Regex.find_all('a(b*)'),
          equal_to([['abb', 'a', 'abbb'], ['ab', 'ab', 'ab']]),
          label='CheckWithDefaultValues')

      assert_that(
          pcol | 'group 1' >> util.Regex.find_all('a(b*)', 1),
          equal_to([['b', 'b', 'b'], ['bb', '', 'bbb']]),
          label='CheckWithGroup1')

      assert_that(
          pcol | 'group 1 non empty' >> util.Regex.find_all(
              'a(b*)', 1, outputEmpty=False),
          equal_to([['b', 'b', 'b'], ['bb', 'bbb']]),
          label='CheckGroup1NonEmpty')

      assert_that(
          pcol | 'named group' >> util.Regex.find_all(
              'a(?P<namedgroup>b*)', 'namedgroup'),
          equal_to([['b', 'b', 'b'], ['bb', '', 'bbb']]),
          label='CheckNamedGroup')

      assert_that(
          pcol | 'all groups' >> util.Regex.find_all(
              'a(?P<namedgroup>b*)', util.Regex.ALL),
          equal_to([[('ab', 'b'), ('ab', 'b'), ('ab', 'b')],
                    [('abb', 'bb'), ('a', ''), ('abbb', 'bbb')]]),
          label='CheckAllGroups')

      assert_that(
          pcol | 'all non empty groups' >> util.Regex.find_all(
              'a(b*)', util.Regex.ALL, outputEmpty=False),
          equal_to([[('ab', 'b'), ('ab', 'b'), ('ab', 'b')],
                    [('abb', 'bb'), ('abbb', 'bbb')]]),
          label='CheckAllNonEmptyGroups')

  def test_find_kv(self):
    with TestPipeline() as p:
      pcol = (p | beam.Create(['a b c d']))
      assert_that(
          pcol | 'key 1' >> util.Regex.find_kv(
              'a (b) (c)',
              1,
          ),
          equal_to([('b', 'a b c')]),
          label='CheckKey1')

      assert_that(
          pcol | 'key 1 group 1' >> util.Regex.find_kv('a (b) (c)', 1, 2),
          equal_to([('b', 'c')]),
          label='CheckKey1Group1')

  def test_find_kv_pattern(self):
    with TestPipeline() as p:
      rc = re.compile("a (b) (c)")
      result = (p | beam.Create(["a b c"]) | util.Regex.find_kv(rc, 1, 2))
      assert_that(result, equal_to([("b", "c")]))

  def test_find_kv_none(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["x y z"])
          | util.Regex.find_kv("a (b) (c)", 1, 2))
      assert_that(result, equal_to([]))

  def test_match(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["a", "x", "y", "z"])
          | util.Regex.matches("[xyz]"))
      assert_that(result, equal_to(["x", "y", "z"]))

    with TestPipeline() as p:
      result = (
          p | beam.Create(["a", "ax", "yby", "zzc"])
          | util.Regex.matches("[xyz]"))
      assert_that(result, equal_to(["y", "z"]))

  def test_match_entire_line(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["a", "x", "y", "ay", "zz"])
          | util.Regex.matches("[xyz]$"))
      assert_that(result, equal_to(["x", "y"]))

  def test_match_pattern(self):
    with TestPipeline() as p:
      rc = re.compile("[xyz]")
      result = (p | beam.Create(["a", "x", "y", "z"]) | util.Regex.matches(rc))
      assert_that(result, equal_to(["x", "y", "z"]))

  def test_match_none(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["a", "b", "c", "d"])
          | util.Regex.matches("[xyz]"))
      assert_that(result, equal_to([]))

  def test_match_group(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["a", "x xxx", "x yyy", "x zzz"])
          | util.Regex.matches("x ([xyz]*)", 1))
      assert_that(result, equal_to(("xxx", "yyy", "zzz")))

  def test_match_group_name(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["a", "x xxx", "x yyy", "x zzz"])
          | util.Regex.matches("x (?P<namedgroup>[xyz]*)", 'namedgroup'))
      assert_that(result, equal_to(("xxx", "yyy", "zzz")))

  def test_match_group_name_pattern(self):
    with TestPipeline() as p:
      rc = re.compile("x (?P<namedgroup>[xyz]*)")
      result = (
          p | beam.Create(["a", "x xxx", "x yyy", "x zzz"])
          | util.Regex.matches(rc, 'namedgroup'))
      assert_that(result, equal_to(("xxx", "yyy", "zzz")))

  def test_match_group_empty(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["a", "b", "c", "d"])
          | util.Regex.matches("x (?P<namedgroup>[xyz]*)", 'namedgroup'))
      assert_that(result, equal_to([]))

  def test_all_matched(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["a x", "x x", "y y", "z z"])
          | util.Regex.all_matches("([xyz]) ([xyz])"))
      expected_result = [["x x", "x", "x"], ["y y", "y", "y"],
                         ["z z", "z", "z"]]
      assert_that(result, equal_to(expected_result))

  def test_all_matched_pattern(self):
    with TestPipeline() as p:
      rc = re.compile("([xyz]) ([xyz])")
      result = (
          p | beam.Create(["a x", "x x", "y y", "z z"])
          | util.Regex.all_matches(rc))
      expected_result = [["x x", "x", "x"], ["y y", "y", "y"],
                         ["z z", "z", "z"]]
      assert_that(result, equal_to(expected_result))

  def test_match_group_kv(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["a b c"])
          | util.Regex.matches_kv("a (b) (c)", 1, 2))
      assert_that(result, equal_to([("b", "c")]))

  def test_match_group_kv_pattern(self):
    with TestPipeline() as p:
      rc = re.compile("a (b) (c)")
      pcol = (p | beam.Create(["a b c"]))
      assert_that(
          pcol | 'key 1' >> util.Regex.matches_kv(rc, 1),
          equal_to([("b", "a b c")]),
          label="CheckKey1")

      assert_that(
          pcol | 'key 1 group 2' >> util.Regex.matches_kv(rc, 1, 2),
          equal_to([("b", "c")]),
          label="CheckKey1Group2")

  def test_match_group_kv_none(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["x y z"])
          | util.Regex.matches_kv("a (b) (c)", 1, 2))
      assert_that(result, equal_to([]))

  def test_match_kv_group_names(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["a b c"]) | util.Regex.matches_kv(
              "a (?P<keyname>b) (?P<valuename>c)", 'keyname', 'valuename'))
      assert_that(result, equal_to([("b", "c")]))

  def test_match_kv_group_names_pattern(self):
    with TestPipeline() as p:
      rc = re.compile("a (?P<keyname>b) (?P<valuename>c)")
      result = (
          p | beam.Create(["a b c"])
          | util.Regex.matches_kv(rc, 'keyname', 'valuename'))
      assert_that(result, equal_to([("b", "c")]))

  def test_match_kv_group_name_none(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["x y z"]) | util.Regex.matches_kv(
              "a (?P<keyname>b) (?P<valuename>c)", 'keyname', 'valuename'))
      assert_that(result, equal_to([]))

  def test_replace_all(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["xj", "yj", "zj"])
          | util.Regex.replace_all("[xyz]", "new"))
      assert_that(result, equal_to(["newj", "newj", "newj"]))

  def test_replace_all_mixed(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["abc", "xj", "yj", "zj", "def"])
          | util.Regex.replace_all("[xyz]", 'new'))
      assert_that(result, equal_to(["abc", "newj", "newj", "newj", "def"]))

  def test_replace_all_mixed_pattern(self):
    with TestPipeline() as p:
      rc = re.compile("[xyz]")
      result = (
          p | beam.Create(["abc", "xj", "yj", "zj", "def"])
          | util.Regex.replace_all(rc, 'new'))
      assert_that(result, equal_to(["abc", "newj", "newj", "newj", "def"]))

  def test_replace_first(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["xjx", "yjy", "zjz"])
          | util.Regex.replace_first("[xyz]", 'new'))
      assert_that(result, equal_to(["newjx", "newjy", "newjz"]))

  def test_replace_first_mixed(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create(["abc", "xjx", "yjy", "zjz", "def"])
          | util.Regex.replace_first("[xyz]", 'new'))
      assert_that(result, equal_to(["abc", "newjx", "newjy", "newjz", "def"]))

  def test_replace_first_mixed_pattern(self):
    with TestPipeline() as p:
      rc = re.compile("[xyz]")
      result = (
          p | beam.Create(["abc", "xjx", "yjy", "zjz", "def"])
          | util.Regex.replace_first(rc, 'new'))
      assert_that(result, equal_to(["abc", "newjx", "newjy", "newjz", "def"]))

  def test_split(self):
    with TestPipeline() as p:
      data = ["The quick brown fox jumps over the lazy dog"]
      result = (p | beam.Create(data) | util.Regex.split("\\W+"))
      expected_result = [[
          "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"
      ]]
      assert_that(result, equal_to(expected_result))

  def test_split_pattern(self):
    with TestPipeline() as p:
      data = ["The quick brown fox jumps over the lazy dog"]
      rc = re.compile("\\W+")
      result = (p | beam.Create(data) | util.Regex.split(rc))
      expected_result = [[
          "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"
      ]]
      assert_that(result, equal_to(expected_result))

  def test_split_with_empty(self):
    # Note: the input deliberately contains runs of consecutive spaces (the
    # expected output below implies 2 after "The", 3 after "quick" and 4
    # after "over"); splitting on a single "\s" then yields empty strings.
    with TestPipeline() as p:
      data = ["The  quick   brown fox jumps over    the lazy dog"]
      result = (p | beam.Create(data) | util.Regex.split("\\s", True))
      expected_result = [[
          'The',
          '',
          'quick',
          '',
          '',
          'brown',
          'fox',
          'jumps',
          'over',
          '',
          '',
          '',
          'the',
          'lazy',
          'dog'
      ]]
      assert_that(result, equal_to(expected_result))

  def test_split_without_empty(self):
    with TestPipeline() as p:
      data = ["The  quick   brown fox jumps over    the lazy dog"]
      result = (p | beam.Create(data) | util.Regex.split("\\s", False))
      expected_result = [[
          "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"
      ]]
      assert_that(result, equal_to(expected_result))
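

# An illustrative sketch, not part of the original suite: as the tests above
# show, Regex.matches anchors the pattern at the start of each element
# (re.match semantics), while Regex.find searches anywhere in it (re.search
# semantics). The input is an assumption chosen for the example.
def _regex_matches_vs_find_sketch():
  with TestPipeline() as p:
    pcol = p | beam.Create(['axyz'])
    assert_that(
        pcol | 'find' >> util.Regex.find('x'), equal_to(['x']), label='Find')
    assert_that(
        pcol | 'matches' >> util.Regex.matches('x'),
        equal_to([]),
        label='Matches')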


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()