# Source: github.com/apache/beam/sdks/v2@v2.48.2/
#   python/apache_beam/runners/worker/sideinputs_test.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for side input utilities."""

# pytype: skip-file

import logging
import time
import unittest

import mock

from apache_beam.coders import observable
from apache_beam.runners.worker import sideinputs


def strip_windows(iterator):
  """Return a list holding the ``.value`` attribute of each element."""
  return [wv.value for wv in iterator]


class FakeSource(object):
  """Test double for a side input source.

  Every call to reader() builds a new FakeSourceReader over the same items.
  """
  def __init__(self, items, notify_observers=False):
    self.items = items
    self._should_notify_observers = notify_observers

  def reader(self):
    return FakeSourceReader(self.items, self._should_notify_observers)


class FakeSourceReader(observable.ObservableMixin):
  """Context-manager reader that yields the given items unchanged.

  When notify_observers is True, iteration first calls
  ``notify_observers(len(items), is_record_size=True)``; in that case
  ``items`` must support len(). test_bytes_read_are_reported relies on this
  notification to check the reported read size.
  """
  def __init__(self, items, notify_observers=False):
    super().__init__()
    self.items = items
    # Flipped by __enter__/__exit__ so a test could check that the reader was
    # used as a context manager.
    self.entered = False
    self.exited = False
    self._should_notify_observers = notify_observers

  def __iter__(self):
    if self._should_notify_observers:
      self.notify_observers(len(self.items), is_record_size=True)
    for item in self.items:
      yield item

  def __enter__(self):
    self.entered = True
    return self

  def __exit__(self, exception_type, exception_value, traceback):
    self.exited = True

  @property
  def returns_windowed_values(self):
    # Items are yielded as plain values. The tests read ``.value`` from what
    # sideinputs returns, so presumably sideinputs wraps them -- confirm.
    return False


class PrefetchingSourceIteratorTest(unittest.TestCase):
  def test_single_source_iterator_fn(self):
    sources = [
        FakeSource([0, 1, 2, 3, 4, 5]),
    ]
    iterator_fn = sideinputs.get_iterator_fn_for_sources(
        sources, max_reader_threads=2)
    self.assertEqual(strip_windows(iterator_fn()), list(range(6)))

  def test_bytes_read_are_reported(self):
    mock_read_counter = mock.MagicMock()
    source_records = ['a', 'b', 'c', 'd']
    sources = [
        FakeSource(source_records, notify_observers=True),
    ]
    iterator_fn = sideinputs.get_iterator_fn_for_sources(
        sources, max_reader_threads=3, read_counter=mock_read_counter)
    self.assertEqual(strip_windows(iterator_fn()), source_records)
    # FakeSourceReader reports len(source_records) == 4 as the record size.
    mock_read_counter.add_bytes_read.assert_called_with(4)

  def test_multiple_sources_iterator_fn(self):
    sources = [
        FakeSource([0]),
        FakeSource([1, 2, 3, 4, 5]),
        FakeSource([]),
        FakeSource([6, 7, 8, 9, 10]),
    ]
    iterator_fn = sideinputs.get_iterator_fn_for_sources(
        sources, max_reader_threads=3)
    # With several reader threads only the contents are checked, not order.
    self.assertEqual(sorted(strip_windows(iterator_fn())), list(range(11)))

  def test_multiple_sources_single_reader_iterator_fn(self):
    sources = [
        FakeSource([0]),
        FakeSource([1, 2, 3, 4, 5]),
        FakeSource([]),
        FakeSource([6, 7, 8, 9, 10]),
    ]
    iterator_fn = sideinputs.get_iterator_fn_for_sources(
        sources, max_reader_threads=1)
    # With a single reader thread the sources' order is expected to hold.
    self.assertEqual(strip_windows(iterator_fn()), list(range(11)))

  def test_source_iterator_single_source_exception(self):
    class MyException(Exception):
      pass

    def exception_generator():
      yield 0
      raise MyException('I am an exception!')

    sources = [
        FakeSource(exception_generator()),
    ]
    iterator_fn = sideinputs.get_iterator_fn_for_sources(sources)
    seen = set()
    with self.assertRaises(MyException):
      for value in iterator_fn():
        seen.add(value.value)
    # The value produced before the failure must still have been delivered.
    self.assertEqual(sorted(seen), [0])

  def test_source_iterator_fn_exception(self):
    class MyException(Exception):
      pass

    def exception_generator():
      yield 0
      time.sleep(0.1)
      raise MyException('I am an exception!')

    def perpetual_generator(value):
      while True:
        yield value
        time.sleep(0.1)

    sources = [
        FakeSource(perpetual_generator(1)),
        FakeSource(perpetual_generator(2)),
        FakeSource(perpetual_generator(3)),
        FakeSource(perpetual_generator(4)),
        FakeSource(exception_generator()),
    ]
    iterator_fn = sideinputs.get_iterator_fn_for_sources(sources)
    seen = set()
    # The never-ending sources would loop forever; the failing source's
    # exception is what must end the iteration.
    with self.assertRaises(MyException):
      for value in iterator_fn():
        seen.add(value.value)
    self.assertEqual(sorted(seen), list(range(5)))


class EmulatedCollectionsTest(unittest.TestCase):
  def test_emulated_iterable(self):
    def _iterable_fn():
      for i in range(10):
        yield i

    iterable = sideinputs.EmulatedIterable(_iterable_fn)
    # Check that multiple iterations are supported.
    for _ in range(0, 5):
      for i, j in enumerate(iterable):
        self.assertEqual(i, j)

  def test_large_iterable_values(self):
    # Here, we create a large collection that would be too big for
    # memory-constrained test environments, but should be under the memory
    # limit if materialized one at a time.
    def _iterable_fn():
      for i in range(10):
        yield ('%d' % i) * (200 * 1024 * 1024)

    iterable = sideinputs.EmulatedIterable(_iterable_fn)
    # Check that multiple iterations are supported.
    for _ in range(0, 3):
      for i, j in enumerate(iterable):
        self.assertEqual(('%d' % i) * (200 * 1024 * 1024), j)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()