github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/utils_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  import unittest
    21  
    22  import mock
    23  
    24  from apache_beam.io import OffsetRangeTracker
    25  from apache_beam.io import source_test_utils
    26  from apache_beam.io.utils import CountingSource
    27  
    28  
    29  class CountingSourceTest(unittest.TestCase):
    30    def setUp(self):
    31      self.source = CountingSource(10)
    32  
    33    def test_estimate_size(self):
    34      self.assertEqual(10, self.source.estimate_size())
    35  
    36    @mock.patch('apache_beam.io.utils.OffsetRangeTracker')
    37    def test_get_range_tracker(self, mock_tracker):
    38      _ = self.source.get_range_tracker(None, None)
    39      mock_tracker.assert_called_with(0, 10)
    40      _ = self.source.get_range_tracker(3, 7)
    41      mock_tracker.assert_called_with(3, 7)
    42  
    43    def test_read(self):
    44      tracker = OffsetRangeTracker(3, 6)
    45      res = list(self.source.read(tracker))
    46      self.assertEqual([3, 4, 5], res)
    47  
    48    def test_split(self):
    49      for size in [1, 3, 10]:
    50        splits = list(self.source.split(desired_bundle_size=size))
    51  
    52        reference_info = (self.source, None, None)
    53        sources_info = ([
    54            (split.source, split.start_position, split.stop_position)
    55            for split in splits
    56        ])
    57        source_test_utils.assert_sources_equal_reference_source(
    58            reference_info, sources_info)
    59  
    60    def test_dynamic_work_rebalancing(self):
    61      splits = list(self.source.split(desired_bundle_size=20))
    62      assert len(splits) == 1
    63      source_test_utils.assert_split_at_fraction_exhaustive(
    64          splits[0].source, splits[0].start_position, splits[0].stop_position)
    65  
    66  
    67  if __name__ == '__main__':
    68    unittest.main()