github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/util_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the transforms.util classes."""

# pytype: skip-file

import logging
import math
import random
import re
import time
import unittest
import warnings
from datetime import datetime

import pytest
import pytz

import apache_beam as beam
from apache_beam import GroupByKey
from apache_beam import Map
from apache_beam import WindowInto
from apache_beam.coders import coders
from apache_beam.metrics import MetricsFilter
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import AsList
from apache_beam.pvalue import AsSingleton
from apache_beam.runners import pipeline_context
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import SortLists
from apache_beam.testing.util import TestWindowedValue
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import contains_in_any_order
from apache_beam.testing.util import equal_to
from apache_beam.transforms import trigger
from apache_beam.transforms import util
from apache_beam.transforms import window
from apache_beam.transforms.core import FlatMapTuple
from apache_beam.transforms.trigger import AfterCount
from apache_beam.transforms.trigger import Repeatedly
from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import IntervalWindow
from apache_beam.transforms.window import Sessions
from apache_beam.transforms.window import SlidingWindows
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import typehints
from apache_beam.typehints.sharded_key_type import ShardedKeyType
from apache_beam.utils import proto_utils
from apache_beam.utils import timestamp
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import MIN_TIMESTAMP
from apache_beam.utils.windowed_value import WindowedValue

# Note: the module is 'transforms' (plural); the previous filter pattern
# ('apache_beam.transform.util_test') never matched this module.
warnings.filterwarnings(
    'ignore',
    category=FutureWarning,
    module='apache_beam.transforms.util_test')


class CoGroupByKeyTest(unittest.TestCase):
  def test_co_group_by_key_on_tuple(self):
    with TestPipeline() as pipeline:
      pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
                                                     ('b', 3), ('c', 4)])
      pcoll_2 = pipeline | 'Start 2' >> beam.Create([('a', 5), ('a', 6),
                                                     ('c', 7), ('c', 8)])
      result = (pcoll_1, pcoll_2) | beam.CoGroupByKey() | SortLists
      assert_that(
          result,
          equal_to([('a', ([1, 2], [5, 6])), ('b', ([3], [])),
                    ('c', ([4], [7, 8]))]))

  def test_co_group_by_key_on_iterable(self):
    with TestPipeline() as pipeline:
      pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
                                                     ('b', 3), ('c', 4)])
      pcoll_2 = pipeline | 'Start 2' >> beam.Create([('a', 5), ('a', 6),
                                                     ('c', 7), ('c', 8)])
      result = iter([pcoll_1, pcoll_2]) | beam.CoGroupByKey() | SortLists
      assert_that(
          result,
          equal_to([('a', ([1, 2], [5, 6])), ('b', ([3], [])),
                    ('c', ([4], [7, 8]))]))

  def test_co_group_by_key_on_list(self):
    with TestPipeline() as pipeline:
      pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
                                                     ('b', 3), ('c', 4)])
      pcoll_2 = pipeline | 'Start 2' >> beam.Create([('a', 5), ('a', 6),
                                                     ('c', 7), ('c', 8)])
      result = [pcoll_1, pcoll_2] | beam.CoGroupByKey() | SortLists
      assert_that(
          result,
          equal_to([('a', ([1, 2], [5, 6])), ('b', ([3], [])),
                    ('c', ([4], [7, 8]))]))

  def test_co_group_by_key_on_dict(self):
    with TestPipeline() as pipeline:
      pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
                                                     ('b', 3), ('c', 4)])
      pcoll_2 = pipeline | 'Start 2' >> beam.Create([('a', 5), ('a', 6),
                                                     ('c', 7), ('c', 8)])
      result = {'X': pcoll_1, 'Y': pcoll_2} | beam.CoGroupByKey() | SortLists
      assert_that(
          result,
          equal_to([('a', {
              'X': [1, 2], 'Y': [5, 6]
          }), ('b', {
              'X': [3], 'Y': []
          }), ('c', {
              'X': [4], 'Y': [7, 8]
          })]))

  def test_co_group_by_key_on_dict_with_tuple_keys(self):
    with TestPipeline() as pipeline:
      key = ('a', ('b', 'c'))
      pcoll_1 = pipeline | 'Start 1' >> beam.Create([(key, 1)])
      pcoll_2 = pipeline | 'Start 2' >> beam.Create([(key, 2)])
      result = {'X': pcoll_1, 'Y': pcoll_2} | beam.CoGroupByKey() | SortLists
      assert_that(result, equal_to([(key, {'X': [1], 'Y': [2]})]))

  def test_co_group_by_key_on_empty(self):
    with TestPipeline() as pipeline:
      assert_that(
          tuple() | 'EmptyTuple' >> beam.CoGroupByKey(pipeline=pipeline),
          equal_to([]),
          label='AssertEmptyTuple')
      assert_that([] | 'EmptyList' >> beam.CoGroupByKey(pipeline=pipeline),
                  equal_to([]),
                  label='AssertEmptyList')
      assert_that(
          iter([]) | 'EmptyIterable' >> beam.CoGroupByKey(pipeline=pipeline),
          equal_to([]),
          label='AssertEmptyIterable')
      assert_that({} | 'EmptyDict' >> beam.CoGroupByKey(pipeline=pipeline),
                  equal_to([]),
                  label='AssertEmptyDict')

  def test_co_group_by_key_on_one(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | beam.Create([('a', 1), ('b', 2)])
      expected = [('a', ([1], )), ('b', ([2], ))]
      assert_that((pcoll, ) | 'OneTuple' >> beam.CoGroupByKey(),
                  equal_to(expected),
                  label='AssertOneTuple')
      assert_that([pcoll] | 'OneList' >> beam.CoGroupByKey(),
                  equal_to(expected),
                  label='AssertOneList')
      assert_that(
          iter([pcoll]) | 'OneIterable' >> beam.CoGroupByKey(),
          equal_to(expected),
          label='AssertOneIterable')
      assert_that({'tag': pcoll}
                  | 'OneDict' >> beam.CoGroupByKey()
                  | beam.MapTuple(lambda k, v: (k, (v['tag'], ))),
                  equal_to(expected),
                  label='AssertOneDict')


class FakeClock(object):
  def __init__(self, now=None):
    # Guard against the classic default-argument pitfall: `now=time.time()`
    # in the signature would be evaluated once, at class definition time.
    self._now = time.time() if now is None else now

  def __call__(self):
    return self._now

  def sleep(self, duration):
    self._now += duration


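def _fake_clock_usage_sketch():
  # A minimal sketch (not exercised by the suite) of the FakeClock contract
  # the batching tests below rely on: calling the instance reads the fake
  # time, and sleep() advances it without any real waiting.
  clock = FakeClock(now=0)
  clock.sleep(2.5)
  assert clock() == 2.5

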
class BatchElementsTest(unittest.TestCase):
  def test_constant_batch(self):
    # Assumes a single bundle...
    p = TestPipeline()
    output = (
        p
        | beam.Create(range(35))
        | util.BatchElements(min_batch_size=10, max_batch_size=10)
        | beam.Map(len))
    assert_that(output, equal_to([10, 10, 10, 5]))
    res = p.run()
    res.wait_until_finish()
    metrics = res.metrics()
    results = metrics.query(MetricsFilter().with_name("batch_size"))
    self.assertEqual(len(results["distributions"]), 1)

  def test_constant_batch_no_metrics(self):
    p = TestPipeline()
    output = (
        p
        | beam.Create(range(35))
        | util.BatchElements(
            min_batch_size=10, max_batch_size=10, record_metrics=False)
        | beam.Map(len))
    assert_that(output, equal_to([10, 10, 10, 5]))
    res = p.run()
    res.wait_until_finish()
    metrics = res.metrics()
    results = metrics.query(MetricsFilter().with_name("batch_size"))
    self.assertEqual(len(results["distributions"]), 0)

  def test_grows_to_max_batch(self):
    # Assumes a single bundle...
    with TestPipeline() as p:
      res = (
          p
          | beam.Create(range(164))
          | util.BatchElements(
              min_batch_size=1, max_batch_size=50, clock=FakeClock())
          | beam.Map(len))
      assert_that(res, equal_to([1, 1, 2, 4, 8, 16, 32, 50, 50]))

  def test_windowed_batches(self):
    # Assumes a single bundle, in order...
    with TestPipeline() as p:
      res = (
          p
          | beam.Create(range(47), reshuffle=False)
          | beam.Map(lambda t: window.TimestampedValue(t, t))
          | beam.WindowInto(window.FixedWindows(30))
          | util.BatchElements(
              min_batch_size=5, max_batch_size=10, clock=FakeClock())
          | beam.Map(len))
      assert_that(
          res,
          equal_to([
              5,
              5,
              10,
              10,  # elements in [0, 30)
              10,
              7,  # elements in [30, 47)
          ]))

  def test_global_batch_timestamps(self):
    # Assumes a single bundle
    with TestPipeline() as p:
      res = (
          p
          | beam.Create(range(3), reshuffle=False)
          | util.BatchElements(min_batch_size=2, max_batch_size=2)
          | beam.Map(
              lambda batch,
              timestamp=beam.DoFn.TimestampParam: (len(batch), timestamp)))
      assert_that(
          res,
          equal_to([
              (2, GlobalWindow().max_timestamp()),
              (1, GlobalWindow().max_timestamp()),
          ]))

  def test_sized_batches(self):
    with TestPipeline() as p:
      res = (
          p
          | beam.Create([
              'a', 'a', 'aaaaaaaaaa',  # First batch.
              'aaaaaa', 'aaaaa',       # Second batch.
              'a', 'aaaaaaa', 'a', 'a' # Third batch.
              ], reshuffle=False)
          | util.BatchElements(
              min_batch_size=10, max_batch_size=10, element_size_fn=len)
          | beam.Map(lambda batch: ''.join(batch))
          | beam.Map(len))
      assert_that(res, equal_to([12, 11, 10]))

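  def _sized_batch_arithmetic_sketch(self):
    # Not a test: the running-size arithmetic behind test_sized_batches,
    # assuming a batch is cut once the summed element sizes reach
    # max_batch_size=10: 1 + 1 + 10 = 12, then 6 + 5 = 11, then
    # 1 + 7 + 1 + 1 = 10, matching the asserted [12, 11, 10].
    assert [1 + 1 + 10, 6 + 5, 1 + 7 + 1 + 1] == [12, 11, 10]
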
  def test_target_duration(self):
    clock = FakeClock()
    batch_estimator = util._BatchSizeEstimator(
        target_batch_overhead=None, target_batch_duration_secs=10, clock=clock)
    batch_duration = lambda batch_size: 1 + .7 * batch_size
    # 14 * .7 = 9.8 is as close to the target of 10 as we can get
    # (see _cost_model_sketch below).
    expected_sizes = [1, 2, 4, 8, 14, 14, 14]
    actual_sizes = []
    for _ in range(len(expected_sizes)):
      actual_sizes.append(batch_estimator.next_batch_size())
      with batch_estimator.record_time(actual_sizes[-1]):
        clock.sleep(batch_duration(actual_sizes[-1]))
    self.assertEqual(expected_sizes, actual_sizes)

  def test_target_duration_including_fixed_cost(self):
    clock = FakeClock()
    batch_estimator = util._BatchSizeEstimator(
        target_batch_overhead=None,
        target_batch_duration_secs_including_fixed_cost=10,
        clock=clock)
    batch_duration = lambda batch_size: 1 + .7 * batch_size
    # With the fixed cost included, 12 is the largest batch that stays
    # within the target of 10 (1 + 12 * .7 = 9.4, while 1 + 13 * .7 = 10.1).
    expected_sizes = [1, 2, 4, 8, 12, 12, 12]
    actual_sizes = []
    for _ in range(len(expected_sizes)):
      actual_sizes.append(batch_estimator.next_batch_size())
      with batch_estimator.record_time(actual_sizes[-1]):
        clock.sleep(batch_duration(actual_sizes[-1]))
    self.assertEqual(expected_sizes, actual_sizes)

  def test_target_overhead(self):
    clock = FakeClock()
    batch_estimator = util._BatchSizeEstimator(
        target_batch_overhead=.05, target_batch_duration_secs=None, clock=clock)
    batch_duration = lambda batch_size: 1 + .7 * batch_size
    # At 27 items, a batch takes ~20 seconds with 5% (~1 second) overhead.
    expected_sizes = [1, 2, 4, 8, 16, 27, 27, 27]
    actual_sizes = []
    for _ in range(len(expected_sizes)):
      actual_sizes.append(batch_estimator.next_batch_size())
      with batch_estimator.record_time(actual_sizes[-1]):
        clock.sleep(batch_duration(actual_sizes[-1]))
    self.assertEqual(expected_sizes, actual_sizes)

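  def _cost_model_sketch(self):
    # Not a test: worked numbers for the linear cost model
    # duration(n) = 1 + .7 * n shared by the three tests above (a sketch of
    # the arithmetic, not of the estimator's internals).
    duration = lambda n: 1 + .7 * n
    # Variable cost only: 14 fits the 10s target, 15 does not.
    assert duration(14) - 1 <= 10 < duration(15) - 1
    # Fixed cost included: 12 fits, 13 does not.
    assert duration(12) <= 10 < duration(13)
    # A 27-element batch amortizes the 1s fixed cost to ~5% overhead.
    assert abs(1 / duration(27) - 0.05) < 0.001
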
  def test_variance(self):
    clock = FakeClock()
    variance = 0.25
    batch_estimator = util._BatchSizeEstimator(
        target_batch_overhead=.05,
        target_batch_duration_secs=None,
        variance=variance,
        clock=clock)
    batch_duration = lambda batch_size: 1 + .7 * batch_size
    expected_target = 27
    actual_sizes = []
    for _ in range(util._BatchSizeEstimator._MAX_DATA_POINTS - 1):
      actual_sizes.append(batch_estimator.next_batch_size())
      with batch_estimator.record_time(actual_sizes[-1]):
        clock.sleep(batch_duration(actual_sizes[-1]))
    # Check that we're testing a good range of values.
    stable_set = set(actual_sizes[-20:])
    self.assertGreater(len(stable_set), 3)
    self.assertGreater(
        min(stable_set), expected_target - expected_target * variance)
    self.assertLess(
        max(stable_set), expected_target + expected_target * variance)

  def test_ignore_first_n_batch_size(self):
    clock = FakeClock()
    batch_estimator = util._BatchSizeEstimator(
        clock=clock, ignore_first_n_seen_per_batch_size=2)

    expected_sizes = [
        1, 1, 1, 2, 2, 2, 4, 4, 4, 8, 8, 8, 16, 16, 16, 32, 32, 32, 64, 64, 64
    ]
    actual_sizes = []
    for i in range(len(expected_sizes)):
      actual_sizes.append(batch_estimator.next_batch_size())
      with batch_estimator.record_time(actual_sizes[-1]):
        if i % 3 == 2:
          clock.sleep(0.01)
        else:
          clock.sleep(1)

    self.assertEqual(expected_sizes, actual_sizes)

    # Check we only record the third timing.
    expected_data_batch_sizes = [1, 2, 4, 8, 16, 32, 64]
    actual_data_batch_sizes = [x[0] for x in batch_estimator._data]
    self.assertEqual(expected_data_batch_sizes, actual_data_batch_sizes)
    expected_data_timing = [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]
    for i in range(len(expected_data_timing)):
      self.assertAlmostEqual(
          expected_data_timing[i], batch_estimator._data[i][1])

  def test_ignore_next_timing(self):
    clock = FakeClock()
    batch_estimator = util._BatchSizeEstimator(clock=clock)
    batch_estimator.ignore_next_timing()

    expected_sizes = [1, 1, 2, 4, 8, 16]
    actual_sizes = []
    for i in range(len(expected_sizes)):
      actual_sizes.append(batch_estimator.next_batch_size())
      with batch_estimator.record_time(actual_sizes[-1]):
        if i == 0:
          clock.sleep(1)
        else:
          clock.sleep(0.01)

    self.assertEqual(expected_sizes, actual_sizes)

    # Check the first record_time was skipped.
    expected_data_batch_sizes = [1, 2, 4, 8, 16]
    actual_data_batch_sizes = [x[0] for x in batch_estimator._data]
    self.assertEqual(expected_data_batch_sizes, actual_data_batch_sizes)
    expected_data_timing = [0.01, 0.01, 0.01, 0.01, 0.01]
    for i in range(len(expected_data_timing)):
      self.assertAlmostEqual(
          expected_data_timing[i], batch_estimator._data[i][1])

  def _run_regression_test(self, linear_regression_fn, test_outliers):
    xs = [random.random() for _ in range(10)]
    ys = [2 * x + 1 for x in xs]
    a, b = linear_regression_fn(xs, ys)
    self.assertAlmostEqual(a, 1)
    self.assertAlmostEqual(b, 2)

    xs = [1 + random.random() for _ in range(100)]
    ys = [7 * x + 5 + 0.01 * random.random() for x in xs]
    a, b = linear_regression_fn(xs, ys)
    self.assertAlmostEqual(a, 5, delta=0.02)
    self.assertAlmostEqual(b, 7, delta=0.02)

    # Test repeated xs
    xs = [1 + random.random()] * 100
    ys = [7 * x + 5 + 0.01 * random.random() for x in xs]
    a, b = linear_regression_fn(xs, ys)
    self.assertAlmostEqual(a, 0, delta=0.02)
    self.assertAlmostEqual(b, sum(ys) / (len(ys) * xs[0]), delta=0.02)

    if test_outliers:
      xs = [1 + random.random() for _ in range(100)]
      ys = [2 * x + 1 for x in xs]
      a, b = linear_regression_fn(xs, ys)
      self.assertAlmostEqual(a, 1)
      self.assertAlmostEqual(b, 2)

      # An outlier or two doesn't affect the result.
      for _ in range(2):
        xs += [10]
        ys += [30]
        a, b = linear_regression_fn(xs, ys)
        self.assertAlmostEqual(a, 1)
        self.assertAlmostEqual(b, 2)

      # But enough of them, and they're no longer outliers.
      xs += [10] * 10
      ys += [30] * 10
      a, b = linear_regression_fn(xs, ys)
      self.assertLess(a, 0.5)
      self.assertGreater(b, 2.5)

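  def _regression_fit_sketch(self):
    # Not a test: a minimal sketch of the fit interface exercised above.
    # For points on the exact line y = 2x + 1 the fit should recover
    # intercept a ~= 1 and slope b ~= 2.
    a, b = util._BatchSizeEstimator.linear_regression_no_numpy(
        [1., 2., 3.], [3., 5., 7.])
    assert abs(a - 1) < 1e-3 and abs(b - 2) < 1e-3
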
  def test_no_numpy_regression(self):
    self._run_regression_test(
        util._BatchSizeEstimator.linear_regression_no_numpy, False)

  def test_numpy_regression(self):
    try:
      # pylint: disable=wrong-import-order, wrong-import-position
      import numpy as _
    except ImportError:
      self.skipTest('numpy not available')
    self._run_regression_test(
        util._BatchSizeEstimator.linear_regression_numpy, True)


class IdentityWindowTest(unittest.TestCase):
  def test_window_preserved(self):
    expected_timestamp = timestamp.Timestamp(5)
    expected_window = window.IntervalWindow(1.0, 2.0)

    class AddWindowDoFn(beam.DoFn):
      def process(self, element):
        yield WindowedValue(element, expected_timestamp, [expected_window])

    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
      expected_windows = [
          TestWindowedValue(kv, expected_timestamp, [expected_window])
          for kv in data
      ]
      before_identity = (
          pipeline
          | 'start' >> beam.Create(data)
          | 'add_windows' >> beam.ParDo(AddWindowDoFn()))
      assert_that(
          before_identity,
          equal_to(expected_windows),
          label='before_identity',
          reify_windows=True)
      after_identity = (
          before_identity
          | 'window' >> beam.WindowInto(
              beam.transforms.util._IdentityWindowFn(
                  coders.IntervalWindowCoder())))
      assert_that(
          after_identity,
          equal_to(expected_windows),
          label='after_identity',
          reify_windows=True)

  def test_no_window_context_fails(self):
    expected_timestamp = timestamp.Timestamp(5)
    # Assuming the default window function is window.GlobalWindows.
    expected_window = window.GlobalWindow()

    class AddTimestampDoFn(beam.DoFn):
      def process(self, element):
        yield window.TimestampedValue(element, expected_timestamp)

    with self.assertRaisesRegex(ValueError, r'window.*None.*add_timestamps2'):
      with TestPipeline() as pipeline:
        data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
        expected_windows = [
            TestWindowedValue(kv, expected_timestamp, [expected_window])
            for kv in data
        ]
        before_identity = (
            pipeline
            | 'start' >> beam.Create(data)
            | 'add_timestamps' >> beam.ParDo(AddTimestampDoFn()))
        assert_that(
            before_identity,
            equal_to(expected_windows),
            label='before_identity',
            reify_windows=True)
        after_identity = (
            before_identity
            | 'window' >> beam.WindowInto(
                beam.transforms.util._IdentityWindowFn(
                    coders.GlobalWindowCoder()))
            # This DoFn will return TimestampedValues, making
            # WindowFn.AssignContext passed to IdentityWindowFn
            # contain a window of None. IdentityWindowFn should
            # raise an exception.
            | 'add_timestamps2' >> beam.ParDo(AddTimestampDoFn()))
        assert_that(
            after_identity,
            equal_to(expected_windows),
            label='after_identity',
            reify_windows=True)


class ReshuffleTest(unittest.TestCase):
  def test_reshuffle_contents_unchanged(self):
    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
      result = (pipeline | beam.Create(data) | beam.Reshuffle())
      assert_that(result, equal_to(data))

  def test_reshuffle_contents_unchanged_with_buckets(self):
    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
      buckets = 2
      result = (pipeline | beam.Create(data) | beam.Reshuffle(buckets))
      assert_that(result, equal_to(data))

  def test_reshuffle_contents_unchanged_with_wrong_buckets(self):
    wrong_buckets = [0, -1, "wrong", 2.5]
    for wrong_bucket in wrong_buckets:
      with self.assertRaisesRegex(ValueError,
                                  'If `num_buckets` is set, it has to be an '
                                  'integer greater than 0, got %s' %
                                  wrong_bucket):
        beam.Reshuffle(wrong_bucket)

  def test_reshuffle_after_gbk_contents_unchanged(self):
    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
      expected_result = [(1, [1, 2, 3]), (2, [1, 2]), (3, [1])]

      after_gbk = (
          pipeline
          | beam.Create(data)
          | beam.GroupByKey()
          | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
      assert_that(after_gbk, equal_to(expected_result), label='after_gbk')
      after_reshuffle = after_gbk | beam.Reshuffle()
      assert_that(
          after_reshuffle, equal_to(expected_result), label='after_reshuffle')

  def test_reshuffle_timestamps_unchanged(self):
    with TestPipeline() as pipeline:
      timestamp = 5
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
      expected_result = [
          TestWindowedValue(v, timestamp, [GlobalWindow()]) for v in data
      ]
      before_reshuffle = (
          pipeline
          | 'start' >> beam.Create(data)
          | 'add_timestamp' >>
          beam.Map(lambda v: beam.window.TimestampedValue(v, timestamp)))
      assert_that(
          before_reshuffle,
          equal_to(expected_result),
          label='before_reshuffle',
          reify_windows=True)
      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle,
          equal_to(expected_result),
          label='after_reshuffle',
          reify_windows=True)

  def test_reshuffle_windows_unchanged(self):
    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
      expected_data = [
          TestWindowedValue(v, t - .001, [w]) for (v, t, w) in [
              ((1, contains_in_any_order([2, 1])),
               4.0,
               IntervalWindow(1.0, 4.0)),
              ((2, contains_in_any_order([2, 1])),
               4.0,
               IntervalWindow(1.0, 4.0)),
              ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
              ((1, [4]), 6.0, IntervalWindow(4.0, 6.0)),
          ]
      ]
      before_reshuffle = (
          pipeline
          | 'start' >> beam.Create(data)
          | 'add_timestamp' >>
          beam.Map(lambda v: beam.window.TimestampedValue(v, v[1]))
          | 'window' >> beam.WindowInto(Sessions(gap_size=2))
          | 'group_by_key' >> beam.GroupByKey())
      assert_that(
          before_reshuffle,
          equal_to(expected_data),
          label='before_reshuffle',
          reify_windows=True)
      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle,
          equal_to(expected_data),
          label='after_reshuffle',
          reify_windows=True)

  def test_reshuffle_window_fn_preserved(self):
    any_order = contains_in_any_order
    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
      expected_windows = [
          TestWindowedValue(v, t, [w]) for (v, t, w) in [
              ((1, 1), 1.0, IntervalWindow(1.0, 3.0)),
              ((2, 1), 1.0, IntervalWindow(1.0, 3.0)),
              ((3, 1), 1.0, IntervalWindow(1.0, 3.0)),
              ((1, 2), 2.0, IntervalWindow(2.0, 4.0)),
              ((2, 2), 2.0, IntervalWindow(2.0, 4.0)),
              ((1, 4), 4.0, IntervalWindow(4.0, 6.0)),
          ]
      ]
      expected_merged_windows = [
          TestWindowedValue(v, t - .001, [w]) for (v, t, w) in [
              ((1, any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
              ((2, any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
              ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
              ((1, [4]), 6.0, IntervalWindow(4.0, 6.0)),
          ]
      ]
      before_reshuffle = (
          pipeline
          | 'start' >> beam.Create(data)
          | 'add_timestamp' >> beam.Map(lambda v: TimestampedValue(v, v[1]))
          | 'window' >> beam.WindowInto(Sessions(gap_size=2)))
      assert_that(
          before_reshuffle,
          equal_to(expected_windows),
          label='before_reshuffle',
          reify_windows=True)
      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle,
          equal_to(expected_windows),
          label='after_reshuffle',
          reify_windows=True)
      after_group = after_reshuffle | beam.GroupByKey()
      assert_that(
          after_group,
          equal_to(expected_merged_windows),
          label='after_group',
          reify_windows=True)

  def test_reshuffle_global_window(self):
    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
      expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]
      before_reshuffle = (
          pipeline
          | beam.Create(data)
          | beam.WindowInto(GlobalWindows())
          | beam.GroupByKey()
          | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
      assert_that(
          before_reshuffle, equal_to(expected_data), label='before_reshuffle')
      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle, equal_to(expected_data), label='after_reshuffle')

  def test_reshuffle_sliding_window(self):
    with TestPipeline() as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
      window_size = 2
      expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])] * window_size
      before_reshuffle = (
          pipeline
          | beam.Create(data)
          | beam.WindowInto(SlidingWindows(size=window_size, period=1))
          | beam.GroupByKey()
          | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
      assert_that(
          before_reshuffle, equal_to(expected_data), label='before_reshuffle')
      after_reshuffle = before_reshuffle | beam.Reshuffle()
      # If Reshuffle applies the sliding window function a second time there
      # should be extra values for each key.
      assert_that(
          after_reshuffle, equal_to(expected_data), label='after_reshuffle')

  def test_reshuffle_streaming_global_window(self):
    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    with TestPipeline(options=options) as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
      expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]
      before_reshuffle = (
          pipeline
          | beam.Create(data)
          | beam.WindowInto(GlobalWindows())
          | beam.GroupByKey()
          | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
      assert_that(
          before_reshuffle, equal_to(expected_data), label='before_reshuffle')
      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle, equal_to(expected_data), label='after_reshuffle')

  def test_reshuffle_streaming_global_window_with_buckets(self):
    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    with TestPipeline(options=options) as pipeline:
      data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
      expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]
      buckets = 2
      before_reshuffle = (
          pipeline
          | beam.Create(data)
          | beam.WindowInto(GlobalWindows())
          | beam.GroupByKey()
          | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
      assert_that(
          before_reshuffle, equal_to(expected_data), label='before_reshuffle')
      after_reshuffle = before_reshuffle | beam.Reshuffle(buckets)
      assert_that(
          after_reshuffle, equal_to(expected_data), label='after_reshuffle')

  @pytest.mark.it_validatesrunner
  def test_reshuffle_preserves_timestamps(self):
    with TestPipeline() as pipeline:

      # Create a PCollection and assign each element a different timestamp.
      before_reshuffle = (
          pipeline
          | beam.Create([
              {
                  'name': 'foo', 'timestamp': MIN_TIMESTAMP
              },
              {
                  'name': 'foo', 'timestamp': 0
              },
              {
                  'name': 'bar', 'timestamp': 33
              },
              {
                  'name': 'bar', 'timestamp': 0
              },
          ])
          | beam.Map(
              lambda element: beam.window.TimestampedValue(
                  element, element['timestamp'])))

      # Reshuffle the PCollection above; the element timestamps should come
      # through unchanged.
      after_reshuffle = before_reshuffle | beam.Reshuffle()

      # Given an element, emit a string that contains its timestamp and the
      # name field of the element.
      def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam):
        t = str(timestamp)
        if timestamp == MIN_TIMESTAMP:
          t = 'MIN_TIMESTAMP'
        elif timestamp == MAX_TIMESTAMP:
          t = 'MAX_TIMESTAMP'
        return '{} - {}'.format(t, element['name'])

      # Combine each element in before_reshuffle with its timestamp.
      formatted_before_reshuffle = (
          before_reshuffle
          | "Get before_reshuffle timestamp" >> beam.Map(format_with_timestamp))

      # Combine each element in after_reshuffle with its timestamp.
      formatted_after_reshuffle = (
          after_reshuffle
          | "Get after_reshuffle timestamp" >> beam.Map(format_with_timestamp))

      expected_data = [
          'MIN_TIMESTAMP - foo',
          'Timestamp(0) - foo',
          'Timestamp(33) - bar',
          'Timestamp(0) - bar'
      ]

      # formatted_before_reshuffle and formatted_after_reshuffle can't be
      # compared directly, because they are deferred PCollections, while
      # equal_to only takes a concrete argument.
      assert_that(
          formatted_before_reshuffle,
          equal_to(expected_data),
          label="formatted_before_reshuffle")
      assert_that(
          formatted_after_reshuffle,
          equal_to(expected_data),
          label="formatted_after_reshuffle")


class WithKeysTest(unittest.TestCase):
  def setUp(self):
    self.l = [1, 2, 3]

  def test_constant_k(self):
    with TestPipeline() as p:
      pc = p | beam.Create(self.l)
      with_keys = pc | util.WithKeys('k')
    assert_that(with_keys, equal_to([('k', 1), ('k', 2), ('k', 3)], ))

  def test_callable_k(self):
    with TestPipeline() as p:
      pc = p | beam.Create(self.l)
      with_keys = pc | util.WithKeys(lambda x: x * x)
    assert_that(with_keys, equal_to([(1, 1), (4, 2), (9, 3)]))

  @staticmethod
  def _test_args_kwargs_fn(x, multiply, subtract):
    return x * multiply - subtract

  def test_args_kwargs_k(self):
    with TestPipeline() as p:
      pc = p | beam.Create(self.l)
      with_keys = pc | util.WithKeys(
          WithKeysTest._test_args_kwargs_fn, 2, subtract=1)
    assert_that(with_keys, equal_to([(1, 1), (3, 2), (5, 3)]))

  def test_sideinputs(self):
    with TestPipeline() as p:
      pc = p | beam.Create(self.l)
      si1 = AsList(p | "side input 1" >> beam.Create([1, 2, 3]))
      si2 = AsSingleton(p | "side input 2" >> beam.Create([10]))
      with_keys = pc | util.WithKeys(
          lambda x,
          the_list,
          the_singleton: x + sum(the_list) + the_singleton,
          si1,
          the_singleton=si2)
    assert_that(with_keys, equal_to([(17, 1), (18, 2), (19, 3)]))


class GroupIntoBatchesTest(unittest.TestCase):
  NUM_ELEMENTS = 10
  BATCH_SIZE = 5

  @staticmethod
  def _create_test_data():
    scientists = [
        "Einstein",
        "Darwin",
        "Copernicus",
        "Pasteur",
        "Curie",
        "Faraday",
        "Newton",
        "Bohr",
        "Galilei",
        "Maxwell"
    ]

    data = []
    for i in range(GroupIntoBatchesTest.NUM_ELEMENTS):
      index = i % len(scientists)
      data.append(("key", scientists[index]))
    return data

  def test_in_global_window(self):
    with TestPipeline() as pipeline:
      collection = pipeline \
                   | beam.Create(GroupIntoBatchesTest._create_test_data()) \
                   | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE)
      num_batches = collection | beam.combiners.Count.Globally()
      assert_that(
          num_batches,
          equal_to([
              int(
                  math.ceil(
                      GroupIntoBatchesTest.NUM_ELEMENTS /
                      GroupIntoBatchesTest.BATCH_SIZE))
          ]))

  def test_with_sharded_key_in_global_window(self):
    with TestPipeline() as pipeline:
      collection = (
          pipeline
          | beam.Create(GroupIntoBatchesTest._create_test_data())
          | util.GroupIntoBatches.WithShardedKey(
              GroupIntoBatchesTest.BATCH_SIZE))
      num_batches = collection | beam.combiners.Count.Globally()
      assert_that(
          num_batches,
          equal_to([
              int(
                  math.ceil(
                      GroupIntoBatchesTest.NUM_ELEMENTS /
                      GroupIntoBatchesTest.BATCH_SIZE))
          ]))

  def test_buffering_timer_in_fixed_window_streaming(self):
    window_duration = 6
    max_buffering_duration_secs = 100

    start_time = timestamp.Timestamp(0)
    test_stream = (
        TestStream().add_elements([
            TimestampedValue(value, start_time + i) for i,
            value in enumerate(GroupIntoBatchesTest._create_test_data())
        ]).advance_processing_time(150).advance_watermark_to(
            start_time + window_duration).advance_watermark_to(
                start_time + window_duration +
                1).advance_watermark_to_infinity())

    with TestPipeline(options=StandardOptions(streaming=True)) as pipeline:
      # To trigger the processing time timer, use a fake clock with start time
      # being Timestamp(0).
      fake_clock = FakeClock(now=start_time)

      num_elements_per_batch = (
          pipeline | test_stream
          | "fixed window" >> WindowInto(FixedWindows(window_duration))
          | util.GroupIntoBatches(
              GroupIntoBatchesTest.BATCH_SIZE,
              max_buffering_duration_secs,
              fake_clock)
          | "count elements in batch" >> Map(lambda x: (None, len(x[1])))
          | GroupByKey()
          | "global window" >> WindowInto(GlobalWindows())
          | FlatMapTuple(lambda k, vs: vs))

      # Window duration is 6 and batch size is 5, so the first output batch
      # should hold 5 elements (flushed because the batch size was reached).
      expected_0 = 5
      # Only one element is then left in the window, so the next batch should
      # hold 1 element (flushed because the max buffering duration was
      # reached).
      expected_1 = 1
      # The collection has 10 elements; the remaining 4 fall in the second
      # window, so the last batch should hold 4 elements (flushed when the
      # end of the window is reached). See the arithmetic sketch after this
      # test.
      expected_2 = 4
      assert_that(
          num_elements_per_batch,
          equal_to([expected_0, expected_1, expected_2]),
          "assert2")

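  def _fixed_window_flush_arithmetic_sketch(self):
    # Not a test: a sketch of the arithmetic behind the expected batch sizes
    # above. The ten elements carry timestamps 0..9, so the 6-second fixed
    # windows split them six into [0, 6) and four into [6, 12).
    in_first_window = [t for t in range(10) if t < 6]
    in_second_window = [t for t in range(10) if t >= 6]
    assert (len(in_first_window), len(in_second_window)) == (6, 4)
    # First window: 5 flushed on batch size, 1 on the buffering timer;
    # second window: all 4 flushed when the window ends.
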
  def test_buffering_timer_in_global_window_streaming(self):
    max_buffering_duration_secs = 42

    start_time = timestamp.Timestamp(0)
    test_stream = TestStream().advance_watermark_to(start_time)
    for i, value in enumerate(GroupIntoBatchesTest._create_test_data()):
      test_stream.add_elements(
          [TimestampedValue(value, start_time + i)]) \
        .advance_processing_time(5)
    test_stream.advance_watermark_to(
        start_time + GroupIntoBatchesTest.NUM_ELEMENTS + 1) \
      .advance_watermark_to_infinity()

    with TestPipeline(options=StandardOptions(streaming=True)) as pipeline:
      # Set a batch size larger than the total number of elements.
      # Since we're in a global window, we would have been waiting
      # for all the elements to arrive without the buffering time limit.
      batch_size = GroupIntoBatchesTest.NUM_ELEMENTS * 2

      # To trigger the processing time timer, use a fake clock with start time
      # being Timestamp(0). Since the fake clock never advances during
      # pipeline execution, the timer is always set to the same value, so it
      # fires on every element after the first firing.
      fake_clock = FakeClock(now=start_time)

      num_elements_per_batch = (
          pipeline | test_stream
          | WindowInto(
              GlobalWindows(),
              trigger=Repeatedly(AfterCount(1)),
              accumulation_mode=trigger.AccumulationMode.DISCARDING)
          | util.GroupIntoBatches(
              batch_size, max_buffering_duration_secs, fake_clock)
          | 'count elements in batch' >> Map(lambda x: (None, len(x[1])))
          | GroupByKey()
          | FlatMapTuple(lambda k, vs: vs))

      # The batches flush twice: once when the max buffering duration is
      # reached and once when the global window ends.
      assert_that(num_elements_per_batch, equal_to([9, 1]))

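  def _global_window_flush_sketch(self):
    # Not a test: with batch_size set to twice the element count, size-based
    # flushes never fire, so the [9, 1] split above is one buffering-timer
    # flush of nine elements followed by a final flush of the remaining one
    # when the global window ends.
    assert 9 + 1 == GroupIntoBatchesTest.NUM_ELEMENTS
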
  def test_output_typehints(self):
    transform = util.GroupIntoBatches.WithShardedKey(
        GroupIntoBatchesTest.BATCH_SIZE)
    unused_input_type = typehints.Tuple[str, str]
    output_type = transform.infer_output_type(unused_input_type)
    self.assertTrue(isinstance(output_type, typehints.TupleConstraint))
    k, v = output_type.tuple_types
    self.assertTrue(isinstance(k, ShardedKeyType))
    self.assertTrue(isinstance(v, typehints.IterableTypeConstraint))

    with TestPipeline() as pipeline:
      collection = (
          pipeline
          | beam.Create([((1, 2), 'a'), ((2, 3), 'b')])
          | util.GroupIntoBatches.WithShardedKey(
              GroupIntoBatchesTest.BATCH_SIZE))
      # assertEqual, not assertTrue: the original assertTrue(x, y) treated
      # the expected type as a failure message and passed vacuously.
      self.assertEqual(
          collection.element_type,
          typehints.Tuple[
              ShardedKeyType[typehints.Tuple[int, int]],  # type: ignore[misc]
              typehints.Iterable[str]])

  def _test_runner_api_round_trip(self, transform, urn):
    context = pipeline_context.PipelineContext()
    proto = transform.to_runner_api(context)
    self.assertEqual(urn, proto.urn)
    payload = (
        proto_utils.parse_Bytes(
            proto.payload, beam_runner_api_pb2.GroupIntoBatchesPayload))
    self.assertEqual(transform.params.batch_size, payload.batch_size)
    self.assertEqual(
        transform.params.max_buffering_duration_secs * 1000,
        payload.max_buffering_duration_millis)

    transform_from_proto = (
        transform.__class__.from_runner_api_parameter(None, payload, None))
    self.assertIsInstance(transform_from_proto, transform.__class__)
    self.assertEqual(transform.params, transform_from_proto.params)

  def test_runner_api(self):
    batch_size = 10
    max_buffering_duration_secs = [None, 0, 5]

    for duration in max_buffering_duration_secs:
      self._test_runner_api_round_trip(
          util.GroupIntoBatches(batch_size, duration),
          common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn)
    self._test_runner_api_round_trip(
        util.GroupIntoBatches(batch_size),
        common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn)

    for duration in max_buffering_duration_secs:
      self._test_runner_api_round_trip(
          util.GroupIntoBatches.WithShardedKey(batch_size, duration),
          common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn)
    self._test_runner_api_round_trip(
        util.GroupIntoBatches.WithShardedKey(batch_size),
        common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn)


class ToStringTest(unittest.TestCase):
  def test_tostring_elements(self):
    with TestPipeline() as p:
      result = (p | beam.Create([1, 1, 2, 3]) | util.ToString.Element())
      assert_that(result, equal_to(["1", "1", "2", "3"]))

  def test_tostring_iterables(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create([("one", "two", "three"), ("four", "five", "six")])
          | util.ToString.Iterables())
      assert_that(result, equal_to(["one,two,three", "four,five,six"]))

  def test_tostring_iterables_with_delimiter(self):
    with TestPipeline() as p:
      data = [("one", "two", "three"), ("four", "five", "six")]
      result = (p | beam.Create(data) | util.ToString.Iterables("\t"))
      assert_that(result, equal_to(["one\ttwo\tthree", "four\tfive\tsix"]))

  def test_tostring_kvs(self):
    with TestPipeline() as p:
      result = (p | beam.Create([("one", 1), ("two", 2)]) | util.ToString.Kvs())
      assert_that(result, equal_to(["one,1", "two,2"]))

  def test_tostring_kvs_delimiter(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create([("one", 1), ("two", 2)]) | util.ToString.Kvs("\t"))
      assert_that(result, equal_to(["one\t1", "two\t2"]))

  def test_tostring_kvs_empty_delimiter(self):
    with TestPipeline() as p:
      result = (
          p | beam.Create([("one", 1), ("two", 2)]) | util.ToString.Kvs(""))
      assert_that(result, equal_to(["one1", "two2"]))


class LogElementsTest(unittest.TestCase):
  @pytest.fixture(scope="function")
  def _capture_stdout_log(request, capsys):
    # Note: for a fixture defined on a test class, pytest passes the test
    # instance as the first argument, so the attribute set on `request` here
    # is read back as `self.captured_stdout` in the test below.
    with TestPipeline() as p:
      result = (
          p | beam.Create([
              TimestampedValue(
                  "event",
                  datetime(2022, 10, 1, 0, 0, 0, 0,
                           tzinfo=pytz.UTC).timestamp()),
              TimestampedValue(
                  "event",
                  datetime(2022, 10, 2, 0, 0, 0, 0,
                           tzinfo=pytz.UTC).timestamp()),
          ])
          | beam.WindowInto(FixedWindows(60))
          | util.LogElements(
              prefix='prefix_', with_window=True, with_timestamp=True))

    request.captured_stdout = capsys.readouterr().out
    return result

  @pytest.mark.usefixtures("_capture_stdout_log")
  def test_stdout_logs(self):
    assert self.captured_stdout == \
      ("prefix_event, timestamp='2022-10-01T00:00:00Z', "
       "window(start=2022-10-01T00:00:00Z, end=2022-10-01T00:01:00Z)\n"
       "prefix_event, timestamp='2022-10-02T00:00:00Z', "
       "window(start=2022-10-02T00:00:00Z, end=2022-10-02T00:01:00Z)\n"), \
      f'Received from stdout: {self.captured_stdout}'

  def test_ptransform_output(self):
    with TestPipeline() as p:
      result = (
          p
          | beam.Create(['a', 'b', 'c'])
          | util.LogElements(prefix='prefix_'))
      assert_that(result, equal_to(['a', 'b', 'c']))


  1145  class ReifyTest(unittest.TestCase):
  1146    def test_timestamp(self):
  1147      l = [
  1148          TimestampedValue('a', 100),
  1149          TimestampedValue('b', 200),
  1150          TimestampedValue('c', 300)
  1151      ]
  1152      expected = [
  1153          TestWindowedValue('a', 100, [GlobalWindow()]),
  1154          TestWindowedValue('b', 200, [GlobalWindow()]),
  1155          TestWindowedValue('c', 300, [GlobalWindow()])
  1156      ]
  1157      with TestPipeline() as p:
  1158        # Map(lambda x: x) PTransform is added after Create here, because when
  1159        # a PCollection of TimestampedValues is created with Create PTransform,
  1160        # the timestamps are not assigned to it. Adding a Map forces the
  1161        # PCollection to go through a DoFn so that the PCollection consists of
  1162        # the elements with timestamps assigned to them instead of a PCollection
  1163        # of TimestampedValue(element, timestamp).
  1164        pc = p | beam.Create(l) | beam.Map(lambda x: x)
  1165        reified_pc = pc | util.Reify.Timestamp()
  1166        assert_that(reified_pc, equal_to(expected), reify_windows=True)
  1167  
  1168    def test_window(self):
  1169      l = [
  1170          GlobalWindows.windowed_value('a', 100),
  1171          GlobalWindows.windowed_value('b', 200),
  1172          GlobalWindows.windowed_value('c', 300)
  1173      ]
  1174      expected = [
  1175          TestWindowedValue(('a', 100, GlobalWindow()), 100, [GlobalWindow()]),
  1176          TestWindowedValue(('b', 200, GlobalWindow()), 200, [GlobalWindow()]),
  1177          TestWindowedValue(('c', 300, GlobalWindow()), 300, [GlobalWindow()])
  1178      ]
  1179      with TestPipeline() as p:
  1180        pc = p | beam.Create(l)
  1181        # Map(lambda x: x) PTransform is added after Create here, because when
  1182        # a PCollection of WindowedValues is created with Create PTransform,
  1183        # the windows are not assigned to it. Adding a Map forces the
  1184        # PCollection to go through a DoFn so that the PCollection consists of
  1185        # the elements with timestamps assigned to them instead of a PCollection
  1186        # of WindowedValue(element, timestamp, window).
  1187        pc = pc | beam.Map(lambda x: x)
  1188        reified_pc = pc | util.Reify.Window()
  1189        assert_that(reified_pc, equal_to(expected), reify_windows=True)
  1190  
  1191    def test_timestamp_in_value(self):
  1192      l = [
  1193          TimestampedValue(('a', 1), 100),
  1194          TimestampedValue(('b', 2), 200),
  1195          TimestampedValue(('c', 3), 300)
  1196      ]
  1197      expected = [
  1198          TestWindowedValue(('a', TimestampedValue(1, 100)),
  1199                            100, [GlobalWindow()]),
  1200          TestWindowedValue(('b', TimestampedValue(2, 200)),
  1201                            200, [GlobalWindow()]),
  1202          TestWindowedValue(('c', TimestampedValue(3, 300)),
  1203                            300, [GlobalWindow()])
  1204      ]
  1205      with TestPipeline() as p:
  1206        pc = p | beam.Create(l) | beam.Map(lambda x: x)
  1207        reified_pc = pc | util.Reify.TimestampInValue()
  1208        assert_that(reified_pc, equal_to(expected), reify_windows=True)
  1209  
  1210    def test_window_in_value(self):
  1211      l = [
  1212          GlobalWindows.windowed_value(('a', 1), 100),
  1213          GlobalWindows.windowed_value(('b', 2), 200),
  1214          GlobalWindows.windowed_value(('c', 3), 300)
  1215      ]
  1216      expected = [
  1217          TestWindowedValue(('a', (1, 100, GlobalWindow())),
  1218                            100, [GlobalWindow()]),
  1219          TestWindowedValue(('b', (2, 200, GlobalWindow())),
  1220                            200, [GlobalWindow()]),
  1221          TestWindowedValue(('c', (3, 300, GlobalWindow())),
  1222                            300, [GlobalWindow()])
  1223      ]
  1224      with TestPipeline() as p:
  1225        # The Map(lambda x: x) workaround is applied here for the same reason
  1226        # as above; it also makes the typehint on Reify.WindowInValue work.
  1227        pc = p | beam.Create(data) | beam.Map(lambda x: x)
  1228        reified_pc = pc | util.Reify.WindowInValue()
  1229        assert_that(reified_pc, equal_to(expected), reify_windows=True)
  1230  
  1231  
  1232  class RegexTest(unittest.TestCase):
  1233    def test_find(self):
  1234      with TestPipeline() as p:
  1235        result = (
  1236            p | beam.Create(["aj", "xj", "yj", "zj"])
  1237            | util.Regex.find("[xyz]"))
  1238        assert_that(result, equal_to(["x", "y", "z"]))
  1239  
  1240    def test_find_pattern(self):
  1241      with TestPipeline() as p:
  1242        rc = re.compile("[xyz]")
  1243        result = (p | beam.Create(["aj", "xj", "yj", "zj"]) | util.Regex.find(rc))
  1244        assert_that(result, equal_to(["x", "y", "z"]))
  1245  
  1246    def test_find_group(self):
  1247      with TestPipeline() as p:
  1248        result = (
  1249            p | beam.Create(["aj", "xj", "yj", "zj"])
  1250            | util.Regex.find("([xyz])j", group=1))
  1251        assert_that(result, equal_to(["x", "y", "z"]))
  1252  
  1253    def test_find_empty(self):
  1254      with TestPipeline() as p:
  1255        result = (
  1256            p | beam.Create(["a", "b", "c", "d"])
  1257            | util.Regex.find("[xyz]"))
  1258        assert_that(result, equal_to([]))
  1259  
  1260    def test_find_group_name(self):
  1261      with TestPipeline() as p:
  1262        result = (
  1263            p | beam.Create(["aj", "xj", "yj", "zj"])
  1264            | util.Regex.find("(?P<namedgroup>[xyz])j", group="namedgroup"))
  1265        assert_that(result, equal_to(["x", "y", "z"]))
  1266  
  1267    def test_find_group_name_pattern(self):
  1268      with TestPipeline() as p:
  1269        rc = re.compile("(?P<namedgroup>[xyz])j")
  1270        result = (
  1271            p | beam.Create(["aj", "xj", "yj", "zj"])
  1272            | util.Regex.find(rc, group="namedgroup"))
  1273        assert_that(result, equal_to(["x", "y", "z"]))
  1274  
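        def test_find_re_semantics_sketch(self):
          # Illustrative sketch (hypothetical test, not in the Beam suite):
          # the find expectations above line up with re.search semantics,
          # which scans the whole string and emits the requested group of the
          # first match.
          self.assertEqual(re.search('([xyz])j', 'xj').group(1), 'x')
          self.assertIsNone(re.search('[xyz]', 'a'))
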
  1275    def test_find_all_groups(self):
  1276      data = ["abb ax abbb", "abc qwerty abcabcd xyz"]
  1277      with TestPipeline() as p:
  1278        pcol = (p | beam.Create(data))
  1279  
  1280        assert_that(
  1281            pcol | 'with default values' >> util.Regex.find_all('a(b*)'),
  1282            equal_to([['abb', 'a', 'abbb'], ['ab', 'ab', 'ab']]),
  1283            label='CheckWithDefaultValues')
  1284  
  1285        assert_that(
  1286            pcol | 'group 1' >> util.Regex.find_all('a(b*)', 1),
  1287            equal_to([['b', 'b', 'b'], ['bb', '', 'bbb']]),
  1288            label='CheckWithGroup1')
  1289  
  1290        assert_that(
  1291            pcol | 'group 1 non empty' >> util.Regex.find_all(
  1292                'a(b*)', 1, outputEmpty=False),
  1293            equal_to([['b', 'b', 'b'], ['bb', 'bbb']]),
  1294            label='CheckGroup1NonEmpty')
  1295  
  1296        assert_that(
  1297            pcol | 'named group' >> util.Regex.find_all(
  1298                'a(?P<namedgroup>b*)', 'namedgroup'),
  1299            equal_to([['b', 'b', 'b'], ['bb', '', 'bbb']]),
  1300            label='CheckNamedGroup')
  1301  
  1302        assert_that(
  1303            pcol | 'all groups' >> util.Regex.find_all(
  1304                'a(?P<namedgroup>b*)', util.Regex.ALL),
  1305            equal_to([[('ab', 'b'), ('ab', 'b'), ('ab', 'b')],
  1306                      [('abb', 'bb'), ('a', ''), ('abbb', 'bbb')]]),
  1307            label='CheckAllGroups')
  1308  
  1309        assert_that(
  1310            pcol | 'all non empty groups' >> util.Regex.find_all(
  1311                'a(b*)', util.Regex.ALL, outputEmpty=False),
  1312            equal_to([[('ab', 'b'), ('ab', 'b'), ('ab', 'b')],
  1313                      [('abb', 'bb'), ('abbb', 'bbb')]]),
  1314            label='CheckAllNonEmptyGroups')
  1315  
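        def test_find_all_groups_re_crosscheck(self):
          # Illustrative cross-check (hypothetical test, not in the Beam
          # suite): the expected values above are consistent with plain
          # re.finditer extraction, with util.Regex.ALL apparently yielding
          # the full match followed by every capture group.
          text = 'abb ax abbb'
          self.assertEqual([m.group(1) for m in re.finditer('a(b*)', text)],
                           ['bb', '', 'bbb'])
          self.assertEqual(
              [(m.group(0), ) + m.groups()
               for m in re.finditer('a(b*)', text)],
              [('abb', 'bb'), ('a', ''), ('abbb', 'bbb')])
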
  1316    def test_find_kv(self):
  1317      with TestPipeline() as p:
  1318        pcol = (p | beam.Create(['a b c d']))
  1319        assert_that(
  1320            pcol | 'key 1' >> util.Regex.find_kv('a (b) (c)', 1),
  1324            equal_to([('b', 'a b c')]),
  1325            label='CheckKey1')
  1326  
  1327        assert_that(
  1328            pcol | 'key 1 group 1' >> util.Regex.find_kv('a (b) (c)', 1, 2),
  1329            equal_to([('b', 'c')]),
  1330            label='CheckKey1Group1')
  1331  
  1332    def test_find_kv_pattern(self):
  1333      with TestPipeline() as p:
  1334        rc = re.compile("a (b) (c)")
  1335        result = (p | beam.Create(["a b c"]) | util.Regex.find_kv(rc, 1, 2))
  1336        assert_that(result, equal_to([("b", "c")]))
  1337  
  1338    def test_find_kv_none(self):
  1339      with TestPipeline() as p:
  1340        result = (
  1341            p | beam.Create(["x y z"])
  1342            | util.Regex.find_kv("a (b) (c)", 1, 2))
  1343        assert_that(result, equal_to([]))
  1344  
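        def test_find_kv_re_semantics_sketch(self):
          # Illustrative sketch (hypothetical test, not in the Beam suite):
          # per the expectations above, find_kv pairs the requested key group
          # with the requested value group, the value defaulting to the full
          # match when no value group is given.
          m = re.search('a (b) (c)', 'a b c d')
          self.assertEqual((m.group(1), m.group(0)), ('b', 'a b c'))
          self.assertEqual((m.group(1), m.group(2)), ('b', 'c'))
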
  1345    def test_match(self):
  1346      with TestPipeline() as p:
  1347        result = (
  1348            p | beam.Create(["a", "x", "y", "z"])
  1349            | util.Regex.matches("[xyz]"))
  1350        assert_that(result, equal_to(["x", "y", "z"]))
  1351  
  1352      with TestPipeline() as p:
  1353        result = (
  1354            p | beam.Create(["a", "ax", "yby", "zzc"])
  1355            | util.Regex.matches("[xyz]"))
  1356        assert_that(result, equal_to(["y", "z"]))
  1357  
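        def test_matches_re_semantics_sketch(self):
          # Illustrative sketch (hypothetical test, not in the Beam suite):
          # the matches expectations above line up with re.match semantics,
          # i.e. the pattern must match at the start of the string and only
          # the matched prefix is emitted.
          self.assertIsNone(re.match('[xyz]', 'ax'))
          self.assertEqual(re.match('[xyz]', 'yby').group(0), 'y')
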
  1358    def test_match_entire_line(self):
  1359      with TestPipeline() as p:
  1360        result = (
  1361            p | beam.Create(["a", "x", "y", "ay", "zz"])
  1362            | util.Regex.matches("[xyz]$"))
  1363        assert_that(result, equal_to(["x", "y"]))
  1364  
  1365    def test_match_pattern(self):
  1366      with TestPipeline() as p:
  1367        rc = re.compile("[xyz]")
  1368        result = (p | beam.Create(["a", "x", "y", "z"]) | util.Regex.matches(rc))
  1369        assert_that(result, equal_to(["x", "y", "z"]))
  1370  
  1371    def test_match_none(self):
  1372      with TestPipeline() as p:
  1373        result = (
  1374            p | beam.Create(["a", "b", "c", "d"])
  1375            | util.Regex.matches("[xyz]"))
  1376        assert_that(result, equal_to([]))
  1377  
  1378    def test_match_group(self):
  1379      with TestPipeline() as p:
  1380        result = (
  1381            p | beam.Create(["a", "x xxx", "x yyy", "x zzz"])
  1382            | util.Regex.matches("x ([xyz]*)", 1))
  1383        assert_that(result, equal_to(["xxx", "yyy", "zzz"]))
  1384  
  1385    def test_match_group_name(self):
  1386      with TestPipeline() as p:
  1387        result = (
  1388            p | beam.Create(["a", "x xxx", "x yyy", "x zzz"])
  1389            | util.Regex.matches("x (?P<namedgroup>[xyz]*)", 'namedgroup'))
  1390        assert_that(result, equal_to(["xxx", "yyy", "zzz"]))
  1391  
  1392    def test_match_group_name_pattern(self):
  1393      with TestPipeline() as p:
  1394        rc = re.compile("x (?P<namedgroup>[xyz]*)")
  1395        result = (
  1396            p | beam.Create(["a", "x xxx", "x yyy", "x zzz"])
  1397            | util.Regex.matches(rc, 'namedgroup'))
  1398        assert_that(result, equal_to(["xxx", "yyy", "zzz"]))
  1399  
  1400    def test_match_group_empty(self):
  1401      with TestPipeline() as p:
  1402        result = (
  1403            p | beam.Create(["a", "b", "c", "d"])
  1404            | util.Regex.matches("x (?P<namedgroup>[xyz]*)", 'namedgroup'))
  1405        assert_that(result, equal_to([]))
  1406  
  1407    def test_all_matched(self):
  1408      with TestPipeline() as p:
  1409        result = (
  1410            p | beam.Create(["a x", "x x", "y y", "z z"])
  1411            | util.Regex.all_matches("([xyz]) ([xyz])"))
  1412        expected_result = [["x x", "x", "x"], ["y y", "y", "y"],
  1413                           ["z z", "z", "z"]]
  1414        assert_that(result, equal_to(expected_result))
  1415  
  1416    def test_all_matched_pattern(self):
  1417      with TestPipeline() as p:
  1418        rc = re.compile("([xyz]) ([xyz])")
  1419        result = (
  1420            p | beam.Create(["a x", "x x", "y y", "z z"])
  1421            | util.Regex.all_matches(rc))
  1422        expected_result = [["x x", "x", "x"], ["y y", "y", "y"],
  1423                           ["z z", "z", "z"]]
  1424        assert_that(result, equal_to(expected_result))
  1425  
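        def test_all_matches_re_semantics_sketch(self):
          # Illustrative sketch (hypothetical test, not in the Beam suite):
          # each row expected above looks like the full match followed by
          # every capture group of that match.
          m = re.match('([xyz]) ([xyz])', 'x x')
          self.assertEqual([m.group(0), m.group(1), m.group(2)],
                           ['x x', 'x', 'x'])
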
  1426    def test_match_group_kv(self):
  1427      with TestPipeline() as p:
  1428        result = (
  1429            p | beam.Create(["a b c"])
  1430            | util.Regex.matches_kv("a (b) (c)", 1, 2))
  1431        assert_that(result, equal_to([("b", "c")]))
  1432  
  1433    def test_match_group_kv_pattern(self):
  1434      with TestPipeline() as p:
  1435        rc = re.compile("a (b) (c)")
  1436        pcol = (p | beam.Create(["a b c"]))
  1437        assert_that(
  1438            pcol | 'key 1' >> util.Regex.matches_kv(rc, 1),
  1439            equal_to([("b", "a b c")]),
  1440            label="CheckKey1")
  1441  
  1442        assert_that(
  1443            pcol | 'key 1 group 2' >> util.Regex.matches_kv(rc, 1, 2),
  1444            equal_to([("b", "c")]),
  1445            label="CheckKey1Group2")
  1446  
  1447    def test_match_group_kv_none(self):
  1448      with TestPipeline() as p:
  1449        result = (
  1450            p | beam.Create(["x y z"])
  1451            | util.Regex.matches_kv("a (b) (c)", 1, 2))
  1452        assert_that(result, equal_to([]))
  1453  
  1454    def test_match_kv_group_names(self):
  1455      with TestPipeline() as p:
  1456        result = (
  1457            p | beam.Create(["a b c"]) | util.Regex.matches_kv(
  1458                "a (?P<keyname>b) (?P<valuename>c)", 'keyname', 'valuename'))
  1459        assert_that(result, equal_to([("b", "c")]))
  1460  
  1461    def test_match_kv_group_names_pattern(self):
  1462      with TestPipeline() as p:
  1463        rc = re.compile("a (?P<keyname>b) (?P<valuename>c)")
  1464        result = (
  1465            p | beam.Create(["a b c"])
  1466            | util.Regex.matches_kv(rc, 'keyname', 'valuename'))
  1467        assert_that(result, equal_to([("b", "c")]))
  1468  
  1469    def test_match_kv_group_name_none(self):
  1470      with TestPipeline() as p:
  1471        result = (
  1472            p | beam.Create(["x y z"]) | util.Regex.matches_kv(
  1473                "a (?P<keyname>b) (?P<valuename>c)", 'keyname', 'valuename'))
  1474        assert_that(result, equal_to([]))
  1475  
  1476    def test_replace_all(self):
  1477      with TestPipeline() as p:
  1478        result = (
  1479            p | beam.Create(["xj", "yj", "zj"])
  1480            | util.Regex.replace_all("[xyz]", "new"))
  1481        assert_that(result, equal_to(["newj", "newj", "newj"]))
  1482  
  1483    def test_replace_all_mixed(self):
  1484      with TestPipeline() as p:
  1485        result = (
  1486            p | beam.Create(["abc", "xj", "yj", "zj", "def"])
  1487            | util.Regex.replace_all("[xyz]", 'new'))
  1488        assert_that(result, equal_to(["abc", "newj", "newj", "newj", "def"]))
  1489  
  1490    def test_replace_all_mixed_pattern(self):
  1491      with TestPipeline() as p:
  1492        rc = re.compile("[xyz]")
  1493        result = (
  1494            p | beam.Create(["abc", "xj", "yj", "zj", "def"])
  1495            | util.Regex.replace_all(rc, 'new'))
  1496        assert_that(result, equal_to(["abc", "newj", "newj", "newj", "def"]))
  1497  
  1498    def test_replace_first(self):
  1499      with TestPipeline() as p:
  1500        result = (
  1501            p | beam.Create(["xjx", "yjy", "zjz"])
  1502            | util.Regex.replace_first("[xyz]", 'new'))
  1503        assert_that(result, equal_to(["newjx", "newjy", "newjz"]))
  1504  
  1505    def test_replace_first_mixed(self):
  1506      with TestPipeline() as p:
  1507        result = (
  1508            p | beam.Create(["abc", "xjx", "yjy", "zjz", "def"])
  1509            | util.Regex.replace_first("[xyz]", 'new'))
  1510        assert_that(result, equal_to(["abc", "newjx", "newjy", "newjz", "def"]))
  1511  
  1512    def test_replace_first_mixed_pattern(self):
  1513      with TestPipeline() as p:
  1514        rc = re.compile("[xyz]")
  1515        result = (
  1516            p | beam.Create(["abc", "xjx", "yjy", "zjz", "def"])
  1517            | util.Regex.replace_first(rc, 'new'))
  1518        assert_that(result, equal_to(["abc", "newjx", "newjy", "newjz", "def"]))
  1519  
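        def test_replace_re_semantics_sketch(self):
          # Illustrative sketch (hypothetical test, not in the Beam suite):
          # the expectations above are consistent with re.sub, with
          # replace_first behaving like a substitution limited to a single
          # occurrence.
          self.assertEqual(re.sub('[xyz]', 'new', 'xjx'), 'newjnew')
          self.assertEqual(re.sub('[xyz]', 'new', 'xjx', count=1), 'newjx')
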
  1520    def test_split(self):
  1521      with TestPipeline() as p:
  1522        data = ["The  quick   brown fox jumps over    the lazy dog"]
  1523        result = (p | beam.Create(data) | util.Regex.split(r"\W+"))
  1524        expected_result = [[
  1525            "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"
  1526        ]]
  1527        assert_that(result, equal_to(expected_result))
  1528  
  1529    def test_split_pattern(self):
  1530      with TestPipeline() as p:
  1531        data = ["The  quick   brown fox jumps over    the lazy dog"]
  1532        rc = re.compile(r"\W+")
  1533        result = (p | beam.Create(data) | util.Regex.split(rc))
  1534        expected_result = [[
  1535            "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"
  1536        ]]
  1537        assert_that(result, equal_to(expected_result))
  1538  
  1539    def test_split_with_empty(self):
  1540      with TestPipeline() as p:
  1541        data = ["The  quick   brown fox jumps over    the lazy dog"]
  1542        result = (
                  p | beam.Create(data) | util.Regex.split(r"\s", outputEmpty=True))
  1543        expected_result = [[
  1544            'The',
  1545            '',
  1546            'quick',
  1547            '',
  1548            '',
  1549            'brown',
  1550            'fox',
  1551            'jumps',
  1552            'over',
  1553            '',
  1554            '',
  1555            '',
  1556            'the',
  1557            'lazy',
  1558            'dog'
  1559        ]]
  1560        assert_that(result, equal_to(expected_result))
  1561  
  1562    def test_split_without_empty(self):
  1563      with TestPipeline() as p:
  1564        data = ["The  quick   brown fox jumps over    the lazy dog"]
  1565        result = (
                  p | beam.Create(data) | util.Regex.split(r"\s", outputEmpty=False))
  1566        expected_result = [[
  1567            "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"
  1568        ]]
  1569        assert_that(result, equal_to(expected_result))
  1570  
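        def test_split_re_semantics_sketch(self):
          # Illustrative sketch (hypothetical test, not in the Beam suite):
          # the with/without-empty expectations above are consistent with
          # re.split, with outputEmpty=False dropping the empty strings.
          parts = re.split(r'\s', 'a  b')
          self.assertEqual(parts, ['a', '', 'b'])
          self.assertEqual([x for x in parts if x], ['a', 'b'])
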
  1571  
  1572  if __name__ == '__main__':
  1573    logging.getLogger().setLevel(logging.INFO)
  1574    unittest.main()