#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for classes in iobase.py."""

# pytype: skip-file

import unittest

import mock

import apache_beam as beam
from apache_beam.io.concat_source import ConcatSource
from apache_beam.io.concat_source_test import RangeSource
from apache_beam.io import iobase
from apache_beam.io import range_trackers
from apache_beam.io.iobase import SourceBundle
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to


def _bundle_positions(restrictions):
  """Returns the (start_position, stop_position) pair of each restriction.

  Each item is expected to expose a ``_source_bundle`` attribute, as the
  restrictions produced by the provider's ``split`` do in the tests below.
  """
  return [(
      restriction._source_bundle.start_position,
      restriction._source_bundle.stop_position)
          for restriction in restrictions]


class SDFBoundedSourceRestrictionProviderTest(unittest.TestCase):
  """Tests for iobase._SDFBoundedSourceRestrictionProvider."""
  def setUp(self):
    """Creates a RangeSource over [0, 4) and a provider with chunk size 2."""
    self.initial_range_start = 0
    self.initial_range_stop = 4
    self.initial_range_source = RangeSource(
        self.initial_range_start, self.initial_range_stop)
    self.sdf_restriction_provider = (
        iobase._SDFBoundedSourceRestrictionProvider(desired_chunk_size=2))

  def test_initial_restriction(self):
    """The initial restriction wraps the whole source and has no tracker."""
    element = self.initial_range_source
    restriction = self.sdf_restriction_provider.initial_restriction(element)
    self.assertIsInstance(restriction, iobase._SDFBoundedSourceRestriction)
    self.assertIsInstance(restriction._source_bundle, SourceBundle)
    self.assertEqual(
        self.initial_range_start, restriction._source_bundle.start_position)
    self.assertEqual(
        self.initial_range_stop, restriction._source_bundle.stop_position)
    self.assertIsInstance(restriction._source_bundle.source, RangeSource)
    # Identity check: the range tracker must not have been created yet.
    self.assertIsNone(restriction._range_tracker)

  def test_create_tracker(self):
    """create_tracker returns a tracker spanning the bundle's positions."""
    expected_start = 1
    expected_stop = 3
    source_bundle = SourceBundle(
        expected_stop - expected_start,
        RangeSource(1, 3),
        expected_start,
        expected_stop)
    restriction_tracker = self.sdf_restriction_provider.create_tracker(
        iobase._SDFBoundedSourceRestriction(source_bundle))
    self.assertIsInstance(
        restriction_tracker, iobase._SDFBoundedSourceRestrictionTracker)
    self.assertEqual(expected_start, restriction_tracker.start_pos())
    self.assertEqual(expected_stop, restriction_tracker.stop_pos())

  def test_simple_source_split(self):
    """Splitting [0, 4) with chunk size 2 yields (0, 2) and (2, 4)."""
    element = self.initial_range_source
    restriction = self.sdf_restriction_provider.initial_restriction(element)
    expect_splits = [(0, 2), (2, 4)]
    split_bundles = list(
        self.sdf_restriction_provider.split(element, restriction))
    # Checked one by one so a failure names the offending object.
    for bundle in split_bundles:
      self.assertIsInstance(bundle._source_bundle, SourceBundle)

    self.assertEqual(expect_splits, _bundle_positions(split_bundles))

  def test_concat_source_split(self):
    """Splitting with a ConcatSource element yields (0, 2) and (2, 4)."""
    element = self.initial_range_source
    initial_concat_source = ConcatSource([self.initial_range_source])
    sdf_concat_restriction_provider = (
        iobase._SDFBoundedSourceRestrictionProvider(desired_chunk_size=2))
    # NOTE(review): the restriction is built from the plain RangeSource
    # (`element`), not from `initial_concat_source`; only the `split` call
    # receives the ConcatSource. Confirm this is the intended coverage.
    restriction = self.sdf_restriction_provider.initial_restriction(element)
    expect_splits = [(0, 2), (2, 4)]
    split_bundles = list(
        sdf_concat_restriction_provider.split(
            initial_concat_source, restriction))
    for bundle in split_bundles:
      self.assertIsInstance(bundle._source_bundle, SourceBundle)
    self.assertEqual(expect_splits, _bundle_positions(split_bundles))

  def test_restriction_size(self):
    """Each of the two splits of [0, 4) reports a restriction size of 2."""
    element = self.initial_range_source
    restriction = self.sdf_restriction_provider.initial_restriction(element)
    split_1, split_2 = self.sdf_restriction_provider.split(element, restriction)
    split_1_size = self.sdf_restriction_provider.restriction_size(
        element, split_1)
    split_2_size = self.sdf_restriction_provider.restriction_size(
        element, split_2)
    self.assertEqual(2, split_1_size)
    self.assertEqual(2, split_2_size)


class SDFBoundedSourceRestrictionTrackerTest(unittest.TestCase):
  """Tests for iobase._SDFBoundedSourceRestrictionTracker."""
  def setUp(self):
    """Creates a tracker over a RangeSource bundle covering [0, 4)."""
    self.initial_start_pos = 0
    self.initial_stop_pos = 4
    source_bundle = SourceBundle(
        self.initial_stop_pos - self.initial_start_pos,
        RangeSource(self.initial_start_pos, self.initial_stop_pos),
        self.initial_start_pos,
        self.initial_stop_pos)
    self.sdf_restriction_tracker = (
        iobase._SDFBoundedSourceRestrictionTracker(
            iobase._SDFBoundedSourceRestriction(source_bundle)))

  def test_current_restriction_before_split(self):
    """Before any split the current restriction spans the initial range."""
    current_restriction = self.sdf_restriction_tracker.current_restriction()
    self.assertEqual(
        self.initial_start_pos,
        current_restriction._source_bundle.start_position)
    self.assertEqual(
        self.initial_stop_pos, current_restriction._source_bundle.stop_position)
    self.assertEqual(
        self.initial_start_pos,
        current_restriction._range_tracker.start_position())
    self.assertEqual(
        self.initial_stop_pos,
        current_restriction._range_tracker.stop_position())

  def test_current_restriction_after_split(self):
    """After try_split the current restriction matches the primary."""
    fraction_of_remainder = 0.5
    self.sdf_restriction_tracker.try_claim(1)
    expected_restriction, _ = (
        self.sdf_restriction_tracker.try_split(fraction_of_remainder))
    current_restriction = self.sdf_restriction_tracker.current_restriction()
    self.assertEqual(
        expected_restriction._source_bundle, current_restriction._source_bundle)
    self.assertTrue(current_restriction._range_tracker)

  def test_try_split_at_remainder(self):
    """Splitting at 0.4 of the remainder after claiming 0 cuts at 2."""
    fraction_of_remainder = 0.4
    # Tuples are (start_position, stop_position, weight).
    expected_primary = (0, 2, 2.0)
    expected_residual = (2, 4, 2.0)
    self.sdf_restriction_tracker.try_claim(0)
    actual_primary, actual_residual = (
        self.sdf_restriction_tracker.try_split(fraction_of_remainder))
    self.assertEqual(
        expected_primary,
        (
            actual_primary._source_bundle.start_position,
            actual_primary._source_bundle.stop_position,
            actual_primary._source_bundle.weight))
    self.assertEqual(
        expected_residual,
        (
            actual_residual._source_bundle.start_position,
            actual_residual._source_bundle.stop_position,
            actual_residual._source_bundle.weight))
    self.assertEqual(
        actual_primary._source_bundle.weight,
        self.sdf_restriction_tracker.current_restriction().weight())

  def test_try_split_with_any_exception(self):
    """try_split returns None for a range ending at OFFSET_INFINITY.

    NOTE(review): going by the test name, this presumably covers the tracker
    swallowing an exception raised by the underlying split -- confirm against
    iobase.py.
    """
    source_bundle = SourceBundle(
        range_trackers.OffsetRangeTracker.OFFSET_INFINITY,
        RangeSource(0, range_trackers.OffsetRangeTracker.OFFSET_INFINITY),
        0,
        range_trackers.OffsetRangeTracker.OFFSET_INFINITY)
    self.sdf_restriction_tracker = (
        iobase._SDFBoundedSourceRestrictionTracker(
            iobase._SDFBoundedSourceRestriction(source_bundle)))
    self.sdf_restriction_tracker.try_claim(0)
    self.assertIsNone(self.sdf_restriction_tracker.try_split(0.5))


class UseSdfBoundedSourcesTests(unittest.TestCase):
  """Pipeline-level tests that read a bounded source via beam.io.Read."""
  def _run_sdf_wrapper_pipeline(self, source, expected_values):
    """Reads `source` in a pipeline and asserts it yields `expected_values`.

    The 'beam_fn_api' experiment is added to the pipeline's DebugOptions when
    it is not already present.
    """
    with beam.Pipeline() as p:
      experiments = (p._options.view_as(DebugOptions).experiments or [])

      # Setup experiment option to enable using SDFBoundedSourceWrapper
      if 'beam_fn_api' not in experiments:
        # Required so mocking below doesn't mock Create used in assert_that.
        experiments.append('beam_fn_api')

      p._options.view_as(DebugOptions).experiments = experiments

      actual = p | beam.io.Read(source)
      assert_that(actual, equal_to(expected_values))

  @mock.patch('apache_beam.io.iobase.SDFBoundedSourceReader.expand')
  def test_sdf_wrapper_overrides_read(self, sdf_wrapper_mock_expand):
    """Read's output comes from the patched SDFBoundedSourceReader.expand."""
    def _fake_wrapper_expand(pbegin):
      return pbegin | beam.Map(lambda x: 'fake')

    sdf_wrapper_mock_expand.side_effect = _fake_wrapper_expand
    self._run_sdf_wrapper_pipeline(RangeSource(0, 4), ['fake'])

  def test_sdf_wrap_range_source(self):
    """Reading RangeSource(0, 4) yields 0, 1, 2 and 3."""
    self._run_sdf_wrapper_pipeline(RangeSource(0, 4), [0, 1, 2, 3])


if __name__ == '__main__':
  unittest.main()