#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for ``apache_beam.io.utils.CountingSource``."""

# pytype: skip-file

import unittest

import mock

from apache_beam.io import OffsetRangeTracker
from apache_beam.io import source_test_utils
from apache_beam.io.utils import CountingSource


class CountingSourceTest(unittest.TestCase):
  """Tests for a ``CountingSource`` built with a count of 10."""

  def setUp(self):
    self.source = CountingSource(10)

  def test_estimate_size(self):
    """estimate_size() of CountingSource(10) is 10."""
    self.assertEqual(10, self.source.estimate_size())

  @mock.patch('apache_beam.io.utils.OffsetRangeTracker')
  def test_get_range_tracker(self, mock_tracker):
    """get_range_tracker builds an OffsetRangeTracker from its bounds.

    With ``None`` bounds the tracker is constructed with ``(0, 10)``;
    explicit bounds are passed through unchanged.
    """
    _ = self.source.get_range_tracker(None, None)
    mock_tracker.assert_called_with(0, 10)
    _ = self.source.get_range_tracker(3, 7)
    mock_tracker.assert_called_with(3, 7)

  def test_read(self):
    """Reading through an OffsetRangeTracker(3, 6) yields 3, 4 and 5."""
    tracker = OffsetRangeTracker(3, 6)
    res = list(self.source.read(tracker))
    self.assertEqual([3, 4, 5], res)

  def test_split(self):
    """Splits at several bundle sizes must read the same records as the
    unsplit source (checked by ``source_test_utils``)."""
    # The unsplit source with unbounded start/stop is the reference.
    reference_info = (self.source, None, None)
    for size in [1, 3, 10]:
      # subTest reports which bundle size failed instead of stopping at
      # the first failing iteration with no context.
      with self.subTest(desired_bundle_size=size):
        splits = list(self.source.split(desired_bundle_size=size))
        sources_info = [
            (split.source, split.start_position, split.stop_position)
            for split in splits
        ]
        source_test_utils.assert_sources_equal_reference_source(
            reference_info, sources_info)

  def test_dynamic_work_rebalancing(self):
    """Exhaustively exercises split-at-fraction on a single split."""
    # A desired bundle size (20) larger than the source (10 elements) is
    # expected to produce exactly one split.
    splits = list(self.source.split(desired_bundle_size=20))
    # assertEqual rather than a bare assert: it survives ``python -O`` and
    # reports the actual split count on failure.
    self.assertEqual(1, len(splits))
    source_test_utils.assert_split_at_fraction_exhaustive(
        splits[0].source, splits[0].start_position, splits[0].stop_position)


if __name__ == '__main__':
  unittest.main()