#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the sources framework."""
# pytype: skip-file

import logging
import unittest

import apache_beam as beam
from apache_beam.io import iobase
from apache_beam.io import range_trackers
from apache_beam.io import source_test_utils
from apache_beam.io.concat_source import ConcatSource
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

__all__ = ['RangeSource']


class RangeSource(iobase.BoundedSource):
  """Test source producing the multiples of ``split_freq`` in ``[start, end)``.

  Positions are plain integer offsets. ``read`` only claims (and yields)
  offsets that are multiples of ``split_freq``.
  """

  # Instances define __eq__ for test comparisons and are deliberately
  # unhashable.
  __hash__ = None  # type: ignore[assignment]

  def __init__(self, start, end, split_freq=1):
    """Creates a source over the half-open offset range ``[start, end)``.

    Args:
      start: first offset of the range (inclusive).
      end: last offset of the range (exclusive); must be >= ``start``.
      split_freq: only offsets that are multiples of this value are emitted.
    """
    assert start <= end
    self._start = start
    self._end = end
    self._split_freq = split_freq

  def _normalize(self, start_position, end_position):
    """Replaces ``None`` positions with this source's own bounds."""
    return (
        self._start if start_position is None else start_position,
        self._end if end_position is None else end_position)

  def _round_up(self, index):
    """Rounds up to the nearest multiple of split_freq."""
    # With a negative divisor Python's % yields a value in (-split_freq, 0],
    # so subtracting it moves index up to the next multiple (or leaves it
    # unchanged when it already is one).
    return index - index % -self._split_freq

  def estimate_size(self):
    """Returns the number of offsets covered by this source."""
    return self._end - self._start

  def split(self, desired_bundle_size, start_position=None, end_position=None):
    """Yields consecutive ``SourceBundle``s of at most ``desired_bundle_size``.

    Args:
      desired_bundle_size: maximum number of offsets per bundle.
      start_position: first offset to cover; defaults to the source start.
      end_position: offset to stop before; defaults to the source end.

    Yields:
      ``iobase.SourceBundle`` objects whose weight is the bundle's length and
      whose source is a ``RangeSource`` over exactly that sub-range.
    """
    start, end = self._normalize(start_position, end_position)
    for sub_start in range(start, end, desired_bundle_size):
      # Clamp to the requested end (not self._end) so that an explicit
      # end_position smaller than the source's own end is honoured.
      sub_end = min(end, sub_start + desired_bundle_size)
      yield iobase.SourceBundle(
          sub_end - sub_start,
          RangeSource(sub_start, sub_end, self._split_freq),
          sub_start,
          sub_end)

  def get_range_tracker(self, start_position, end_position):
    """Returns an ``OffsetRangeTracker`` over the normalized range."""
    start, end = self._normalize(start_position, end_position)
    return range_trackers.OffsetRangeTracker(start, end)

  def read(self, range_tracker):
    """Yields each multiple of ``split_freq`` the tracker allows claiming.

    Iteration stops as soon as ``try_claim`` returns a falsy value.
    """
    for k in range(self._round_up(range_tracker.start_position()),
                   self._round_up(range_tracker.stop_position())):
      if k % self._split_freq == 0:
        if not range_tracker.try_claim(k):
          return
        yield k

  # For testing
  def __eq__(self, other):
    return (
        type(self) == type(other) and self._start == other._start and
        self._end == other._end and self._split_freq == other._split_freq)


class ConcatSourceTest(unittest.TestCase):
  def test_range_source(self):
    source_test_utils.assert_split_at_fraction_exhaustive(RangeSource(0, 10, 3))

  def test_range_source_split_respects_end_position(self):
    # An explicit end_position below the source's own end must bound the
    # last bundle.
    bundles = list(RangeSource(0, 10).split(4, 0, 6))
    self.assertEqual(
        [(b.start_position, b.stop_position) for b in bundles],
        [(0, 4), (4, 6)])
    self.assertEqual([b.weight for b in bundles], [4, 2])

  def test_conact_source(self):
    source = ConcatSource([
        RangeSource(0, 4),
        RangeSource(4, 8),
        RangeSource(8, 12),
        RangeSource(12, 16),
    ])
    self.assertEqual(
        list(source.read(source.get_range_tracker())), list(range(16)))
    self.assertEqual(
        list(source.read(source.get_range_tracker((1, None), (2, 10)))),
        list(range(4, 10)))
    range_tracker = source.get_range_tracker(None, None)
    self.assertEqual(range_tracker.position_at_fraction(0), (0, 0))
    self.assertEqual(range_tracker.position_at_fraction(.5), (2, 8))
    self.assertEqual(range_tracker.position_at_fraction(.625), (2, 10))

    # Simulate a read.
    self.assertEqual(range_tracker.try_claim((0, None)), True)
    self.assertEqual(range_tracker.sub_range_tracker(0).try_claim(2), True)
    self.assertEqual(range_tracker.fraction_consumed(), 0.125)

    self.assertEqual(range_tracker.try_claim((1, None)), True)
    self.assertEqual(range_tracker.sub_range_tracker(1).try_claim(6), True)
    self.assertEqual(range_tracker.fraction_consumed(), 0.375)
    self.assertEqual(range_tracker.try_split((0, 1)), None)
    self.assertEqual(range_tracker.try_split((1, 5)), None)

    self.assertEqual(range_tracker.try_split((3, 14)), ((3, None), 0.75))
    self.assertEqual(range_tracker.try_claim((3, None)), False)
    self.assertEqual(range_tracker.sub_range_tracker(1).try_claim(7), True)
    self.assertEqual(range_tracker.try_claim((2, None)), True)
    self.assertEqual(range_tracker.sub_range_tracker(2).try_claim(9), True)

    self.assertEqual(range_tracker.try_split((2, 8)), None)
    self.assertEqual(range_tracker.try_split((2, 11)), ((2, 11), 11. / 12))
    self.assertEqual(range_tracker.sub_range_tracker(2).try_claim(10), True)
    self.assertEqual(range_tracker.sub_range_tracker(2).try_claim(11), False)

  def test_fraction_consumed_at_end(self):
    source = ConcatSource([
        RangeSource(0, 2),
        RangeSource(2, 4),
    ])
    range_tracker = source.get_range_tracker((2, None), None)
    self.assertEqual(range_tracker.fraction_consumed(), 1.0)

  def test_estimate_size(self):
    source = ConcatSource([
        RangeSource(0, 10),
        RangeSource(10, 100),
        RangeSource(100, 1000),
    ])
    self.assertEqual(source.estimate_size(), 1000)

  def test_position_at_fration(self):
    ranges = [(0, 4), (4, 16), (16, 24), (24, 32)]
    # Each bundle's weight is its share of the 32 offsets covered overall.
    source = ConcatSource([
        iobase.SourceBundle((end - start) / 32.,
                            RangeSource(start, end),
                            None,
                            None) for start, end in ranges
    ])

    range_tracker = source.get_range_tracker()
    self.assertEqual(range_tracker.position_at_fraction(0), (0, 0))
    self.assertEqual(range_tracker.position_at_fraction(.01), (0, 1))
    self.assertEqual(range_tracker.position_at_fraction(.1), (0, 4))
    self.assertEqual(range_tracker.position_at_fraction(.125), (1, 4))
    self.assertEqual(range_tracker.position_at_fraction(.2), (1, 7))
    self.assertEqual(range_tracker.position_at_fraction(.7), (2, 23))
    self.assertEqual(range_tracker.position_at_fraction(.75), (3, 24))
    self.assertEqual(range_tracker.position_at_fraction(.8), (3, 26))
    self.assertEqual(range_tracker.position_at_fraction(1), (4, None))

    range_tracker = source.get_range_tracker((1, None), (3, None))
    self.assertEqual(range_tracker.position_at_fraction(0), (1, 4))
    self.assertEqual(range_tracker.position_at_fraction(.01), (1, 5))
    self.assertEqual(range_tracker.position_at_fraction(.5), (1, 14))
    self.assertEqual(range_tracker.position_at_fraction(.599), (1, 16))
    self.assertEqual(range_tracker.position_at_fraction(.601), (2, 17))
    self.assertEqual(range_tracker.position_at_fraction(1), (3, None))

  def test_empty_source(self):
    read_all = source_test_utils.read_from_source

    empty = RangeSource(0, 0)
    self.assertEqual(read_all(ConcatSource([])), [])
    self.assertEqual(read_all(ConcatSource([empty])), [])
    self.assertEqual(read_all(ConcatSource([empty, empty])), [])

    range10 = RangeSource(0, 10)
    self.assertEqual(read_all(ConcatSource([range10]), (0, None), (0, 0)), [])
    self.assertEqual(read_all(ConcatSource([range10]), (0, 10), (1, None)), [])
    self.assertEqual(
        read_all(ConcatSource([range10, range10]), (0, 10), (1, 0)), [])

  def test_single_source(self):
    read_all = source_test_utils.read_from_source

    range10 = RangeSource(0, 10)
    self.assertEqual(read_all(ConcatSource([range10])), list(range(10)))
    self.assertEqual(
        read_all(ConcatSource([range10]), (0, 5)), list(range(5, 10)))
    self.assertEqual(
        read_all(ConcatSource([range10]), None, (0, 5)), list(range(5)))

  def test_source_with_empty_ranges(self):
    read_all = source_test_utils.read_from_source

    empty = RangeSource(0, 0)
    self.assertEqual(read_all(empty), [])

    range10 = RangeSource(0, 10)
    self.assertEqual(
        read_all(ConcatSource([empty, empty, range10])), list(range(10)))
    self.assertEqual(
        read_all(ConcatSource([empty, range10, empty])), list(range(10)))
    self.assertEqual(
        read_all(ConcatSource([range10, empty, range10, empty])),
        list(range(10)) + list(range(10)))

  def test_source_with_empty_ranges_exhastive(self):
    empty = RangeSource(0, 0)
    source = ConcatSource([
        empty,
        RangeSource(0, 10),
        empty,
        empty,
        RangeSource(10, 13),
        RangeSource(13, 17),
        empty,
    ])
    source_test_utils.assert_split_at_fraction_exhaustive(source)

  def test_run_concat_direct(self):
    source = ConcatSource([
        RangeSource(0, 10),
        RangeSource(10, 100),
        RangeSource(100, 1000),
    ])
    with TestPipeline() as pipeline:
      pcoll = pipeline | beam.io.Read(source)
      assert_that(pcoll, equal_to(list(range(1000))))

  def test_conact_source_exhaustive(self):
    source = ConcatSource([
        RangeSource(0, 10),
        RangeSource(100, 110),
        RangeSource(1000, 1010),
    ])
    source_test_utils.assert_split_at_fraction_exhaustive(source)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()