github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/sideinputs_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for side input utilities."""
    19  
    20  # pytype: skip-file
    21  
    22  import logging
    23  import time
    24  import unittest
    25  
    26  import mock
    27  
    28  from apache_beam.coders import observable
    29  from apache_beam.runners.worker import sideinputs
    30  
    31  
    32  def strip_windows(iterator):
    33    return [wv.value for wv in iterator]
    34  
    35  
    36  class FakeSource(object):
    37    def __init__(self, items, notify_observers=False):
    38      self.items = items
    39      self._should_notify_observers = notify_observers
    40  
    41    def reader(self):
    42      return FakeSourceReader(self.items, self._should_notify_observers)
    43  
    44  
    45  class FakeSourceReader(observable.ObservableMixin):
    46    def __init__(self, items, notify_observers=False):
    47      super().__init__()
    48      self.items = items
    49      self.entered = False
    50      self.exited = False
    51      self._should_notify_observers = notify_observers
    52  
    53    def __iter__(self):
    54      if self._should_notify_observers:
    55        self.notify_observers(len(self.items), is_record_size=True)
    56      for item in self.items:
    57        yield item
    58  
    59    def __enter__(self):
    60      self.entered = True
    61      return self
    62  
    63    def __exit__(self, exception_type, exception_value, traceback):
    64      self.exited = True
    65  
    66    @property
    67    def returns_windowed_values(self):
    68      return False
    69  
    70  
    71  class PrefetchingSourceIteratorTest(unittest.TestCase):
    72    def test_single_source_iterator_fn(self):
    73      sources = [
    74          FakeSource([0, 1, 2, 3, 4, 5]),
    75      ]
    76      iterator_fn = sideinputs.get_iterator_fn_for_sources(
    77          sources, max_reader_threads=2)
    78      assert list(strip_windows(iterator_fn())) == list(range(6))
    79  
    80    def test_bytes_read_are_reported(self):
    81      mock_read_counter = mock.MagicMock()
    82      source_records = ['a', 'b', 'c', 'd']
    83      sources = [
    84          FakeSource(source_records, notify_observers=True),
    85      ]
    86      iterator_fn = sideinputs.get_iterator_fn_for_sources(
    87          sources, max_reader_threads=3, read_counter=mock_read_counter)
    88      assert list(strip_windows(iterator_fn())) == source_records
    89      mock_read_counter.add_bytes_read.assert_called_with(4)
    90  
    91    def test_multiple_sources_iterator_fn(self):
    92      sources = [
    93          FakeSource([0]),
    94          FakeSource([1, 2, 3, 4, 5]),
    95          FakeSource([]),
    96          FakeSource([6, 7, 8, 9, 10]),
    97      ]
    98      iterator_fn = sideinputs.get_iterator_fn_for_sources(
    99          sources, max_reader_threads=3)
   100      assert sorted(strip_windows(iterator_fn())) == list(range(11))
   101  
   102    def test_multiple_sources_single_reader_iterator_fn(self):
   103      sources = [
   104          FakeSource([0]),
   105          FakeSource([1, 2, 3, 4, 5]),
   106          FakeSource([]),
   107          FakeSource([6, 7, 8, 9, 10]),
   108      ]
   109      iterator_fn = sideinputs.get_iterator_fn_for_sources(
   110          sources, max_reader_threads=1)
   111      assert list(strip_windows(iterator_fn())) == list(range(11))
   112  
   113    def test_source_iterator_single_source_exception(self):
   114      class MyException(Exception):
   115        pass
   116  
   117      def exception_generator():
   118        yield 0
   119        raise MyException('I am an exception!')
   120  
   121      sources = [
   122          FakeSource(exception_generator()),
   123      ]
   124      iterator_fn = sideinputs.get_iterator_fn_for_sources(sources)
   125      seen = set()
   126      with self.assertRaises(MyException):
   127        for value in iterator_fn():
   128          seen.add(value.value)
   129      self.assertEqual(sorted(seen), [0])
   130  
   131    def test_source_iterator_fn_exception(self):
   132      class MyException(Exception):
   133        pass
   134  
   135      def exception_generator():
   136        yield 0
   137        time.sleep(0.1)
   138        raise MyException('I am an exception!')
   139  
   140      def perpetual_generator(value):
   141        while True:
   142          yield value
   143          time.sleep(0.1)
   144  
   145      sources = [
   146          FakeSource(perpetual_generator(1)),
   147          FakeSource(perpetual_generator(2)),
   148          FakeSource(perpetual_generator(3)),
   149          FakeSource(perpetual_generator(4)),
   150          FakeSource(exception_generator()),
   151      ]
   152      iterator_fn = sideinputs.get_iterator_fn_for_sources(sources)
   153      seen = set()
   154      with self.assertRaises(MyException):
   155        for value in iterator_fn():
   156          seen.add(value.value)
   157      self.assertEqual(sorted(seen), list(range(5)))
   158  
   159  
   160  class EmulatedCollectionsTest(unittest.TestCase):
   161    def test_emulated_iterable(self):
   162      def _iterable_fn():
   163        for i in range(10):
   164          yield i
   165  
   166      iterable = sideinputs.EmulatedIterable(_iterable_fn)
   167      # Check that multiple iterations are supported.
   168      for _ in range(0, 5):
   169        for i, j in enumerate(iterable):
   170          self.assertEqual(i, j)
   171  
   172    def test_large_iterable_values(self):
   173      # Here, we create a large collection that would be too big for memory-
   174      # constained test environments, but should be under the memory limit if
   175      # materialized one at a time.
   176      def _iterable_fn():
   177        for i in range(10):
   178          yield ('%d' % i) * (200 * 1024 * 1024)
   179  
   180      iterable = sideinputs.EmulatedIterable(_iterable_fn)
   181      # Check that multiple iterations are supported.
   182      for _ in range(0, 3):
   183        for i, j in enumerate(iterable):
   184          self.assertEqual(('%d' % i) * (200 * 1024 * 1024), j)
   185  
   186  
   187  if __name__ == '__main__':
   188    logging.getLogger().setLevel(logging.INFO)
   189    unittest.main()