github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/iobase_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for classes in iobase.py."""
    19  
    20  # pytype: skip-file
    21  
    22  import unittest
    23  
    24  import mock
    25  
    26  import apache_beam as beam
    27  from apache_beam.io.concat_source import ConcatSource
    28  from apache_beam.io.concat_source_test import RangeSource
    29  from apache_beam.io import iobase
    30  from apache_beam.io import range_trackers
    31  from apache_beam.io.iobase import SourceBundle
    32  from apache_beam.options.pipeline_options import DebugOptions
    33  from apache_beam.testing.util import assert_that
    34  from apache_beam.testing.util import equal_to
    35  
    36  
    37  class SDFBoundedSourceRestrictionProviderTest(unittest.TestCase):
    38    def setUp(self):
    39      self.initial_range_start = 0
    40      self.initial_range_stop = 4
    41      self.initial_range_source = RangeSource(
    42          self.initial_range_start, self.initial_range_stop)
    43      self.sdf_restriction_provider = (
    44          iobase._SDFBoundedSourceRestrictionProvider(desired_chunk_size=2))
    45  
    46    def test_initial_restriction(self):
    47      element = self.initial_range_source
    48      restriction = (self.sdf_restriction_provider.initial_restriction(element))
    49      self.assertTrue(
    50          isinstance(restriction, iobase._SDFBoundedSourceRestriction))
    51      self.assertTrue(isinstance(restriction._source_bundle, SourceBundle))
    52      self.assertEqual(
    53          self.initial_range_start, restriction._source_bundle.start_position)
    54      self.assertEqual(
    55          self.initial_range_stop, restriction._source_bundle.stop_position)
    56      self.assertTrue(isinstance(restriction._source_bundle.source, RangeSource))
    57      self.assertEqual(restriction._range_tracker, None)
    58  
    59    def test_create_tracker(self):
    60      expected_start = 1
    61      expected_stop = 3
    62      source_bundle = SourceBundle(
    63          expected_stop - expected_start,
    64          RangeSource(1, 3),
    65          expected_start,
    66          expected_stop)
    67      restriction_tracker = (
    68          self.sdf_restriction_provider.create_tracker(
    69              iobase._SDFBoundedSourceRestriction(source_bundle)))
    70      self.assertTrue(
    71          isinstance(
    72              restriction_tracker, iobase._SDFBoundedSourceRestrictionTracker))
    73      self.assertEqual(expected_start, restriction_tracker.start_pos())
    74      self.assertEqual(expected_stop, restriction_tracker.stop_pos())
    75  
    76    def test_simple_source_split(self):
    77      element = self.initial_range_source
    78      restriction = (self.sdf_restriction_provider.initial_restriction(element))
    79      expect_splits = [(0, 2), (2, 4)]
    80      split_bundles = list(
    81          self.sdf_restriction_provider.split(element, restriction))
    82      self.assertTrue(
    83          all(
    84              isinstance(bundle._source_bundle, SourceBundle)
    85              for bundle in split_bundles))
    86  
    87      splits = ([(
    88          bundle._source_bundle.start_position,
    89          bundle._source_bundle.stop_position) for bundle in split_bundles])
    90      self.assertEqual(expect_splits, list(splits))
    91  
    92    def test_concat_source_split(self):
    93      element = self.initial_range_source
    94      initial_concat_source = ConcatSource([self.initial_range_source])
    95      sdf_concat_restriction_provider = (
    96          iobase._SDFBoundedSourceRestrictionProvider(desired_chunk_size=2))
    97      restriction = (self.sdf_restriction_provider.initial_restriction(element))
    98      expect_splits = [(0, 2), (2, 4)]
    99      split_bundles = list(
   100          sdf_concat_restriction_provider.split(
   101              initial_concat_source, restriction))
   102      self.assertTrue(
   103          all(
   104              isinstance(bundle._source_bundle, SourceBundle)
   105              for bundle in split_bundles))
   106      splits = ([(
   107          bundle._source_bundle.start_position,
   108          bundle._source_bundle.stop_position) for bundle in split_bundles])
   109      self.assertEqual(expect_splits, list(splits))
   110  
   111    def test_restriction_size(self):
   112      element = self.initial_range_source
   113      restriction = (self.sdf_restriction_provider.initial_restriction(element))
   114      split_1, split_2 = self.sdf_restriction_provider.split(element, restriction)
   115      split_1_size = self.sdf_restriction_provider.restriction_size(
   116          element, split_1)
   117      split_2_size = self.sdf_restriction_provider.restriction_size(
   118          element, split_2)
   119      self.assertEqual(2, split_1_size)
   120      self.assertEqual(2, split_2_size)
   121  
   122  
   123  class SDFBoundedSourceRestrictionTrackerTest(unittest.TestCase):
   124    def setUp(self):
   125      self.initial_start_pos = 0
   126      self.initial_stop_pos = 4
   127      source_bundle = SourceBundle(
   128          self.initial_stop_pos - self.initial_start_pos,
   129          RangeSource(self.initial_start_pos, self.initial_stop_pos),
   130          self.initial_start_pos,
   131          self.initial_stop_pos)
   132      self.sdf_restriction_tracker = (
   133          iobase._SDFBoundedSourceRestrictionTracker(
   134              iobase._SDFBoundedSourceRestriction(source_bundle)))
   135  
   136    def test_current_restriction_before_split(self):
   137      current_restriction = (self.sdf_restriction_tracker.current_restriction())
   138      self.assertEqual(
   139          self.initial_start_pos,
   140          current_restriction._source_bundle.start_position)
   141      self.assertEqual(
   142          self.initial_stop_pos, current_restriction._source_bundle.stop_position)
   143      self.assertEqual(
   144          self.initial_start_pos,
   145          current_restriction._range_tracker.start_position())
   146      self.assertEqual(
   147          self.initial_stop_pos,
   148          current_restriction._range_tracker.stop_position())
   149  
   150    def test_current_restriction_after_split(self):
   151      fraction_of_remainder = 0.5
   152      self.sdf_restriction_tracker.try_claim(1)
   153      expected_restriction, _ = (
   154          self.sdf_restriction_tracker.try_split(fraction_of_remainder))
   155      current_restriction = self.sdf_restriction_tracker.current_restriction()
   156      self.assertEqual(
   157          expected_restriction._source_bundle, current_restriction._source_bundle)
   158      self.assertTrue(current_restriction._range_tracker)
   159  
   160    def test_try_split_at_remainder(self):
   161      fraction_of_remainder = 0.4
   162      expected_primary = (0, 2, 2.0)
   163      expected_residual = (2, 4, 2.0)
   164      self.sdf_restriction_tracker.try_claim(0)
   165      actual_primary, actual_residual = (
   166          self.sdf_restriction_tracker.try_split(fraction_of_remainder))
   167      self.assertEqual(
   168          expected_primary,
   169          (
   170              actual_primary._source_bundle.start_position,
   171              actual_primary._source_bundle.stop_position,
   172              actual_primary._source_bundle.weight))
   173      self.assertEqual(
   174          expected_residual,
   175          (
   176              actual_residual._source_bundle.start_position,
   177              actual_residual._source_bundle.stop_position,
   178              actual_residual._source_bundle.weight))
   179      self.assertEqual(
   180          actual_primary._source_bundle.weight,
   181          self.sdf_restriction_tracker.current_restriction().weight())
   182  
   183    def test_try_split_with_any_exception(self):
   184      source_bundle = SourceBundle(
   185          range_trackers.OffsetRangeTracker.OFFSET_INFINITY,
   186          RangeSource(0, range_trackers.OffsetRangeTracker.OFFSET_INFINITY),
   187          0,
   188          range_trackers.OffsetRangeTracker.OFFSET_INFINITY)
   189      self.sdf_restriction_tracker = (
   190          iobase._SDFBoundedSourceRestrictionTracker(
   191              iobase._SDFBoundedSourceRestriction(source_bundle)))
   192      self.sdf_restriction_tracker.try_claim(0)
   193      self.assertIsNone(self.sdf_restriction_tracker.try_split(0.5))
   194  
   195  
   196  class UseSdfBoundedSourcesTests(unittest.TestCase):
   197    def _run_sdf_wrapper_pipeline(self, source, expected_values):
   198      with beam.Pipeline() as p:
   199        experiments = (p._options.view_as(DebugOptions).experiments or [])
   200  
   201        # Setup experiment option to enable using SDFBoundedSourceWrapper
   202        if 'beam_fn_api' not in experiments:
   203          # Required so mocking below doesn't mock Create used in assert_that.
   204          experiments.append('beam_fn_api')
   205  
   206        p._options.view_as(DebugOptions).experiments = experiments
   207  
   208        actual = p | beam.io.Read(source)
   209        assert_that(actual, equal_to(expected_values))
   210  
   211    @mock.patch('apache_beam.io.iobase.SDFBoundedSourceReader.expand')
   212    def test_sdf_wrapper_overrides_read(self, sdf_wrapper_mock_expand):
   213      def _fake_wrapper_expand(pbegin):
   214        return pbegin | beam.Map(lambda x: 'fake')
   215  
   216      sdf_wrapper_mock_expand.side_effect = _fake_wrapper_expand
   217      self._run_sdf_wrapper_pipeline(RangeSource(0, 4), ['fake'])
   218  
   219    def test_sdf_wrap_range_source(self):
   220      self._run_sdf_wrapper_pipeline(RangeSource(0, 4), [0, 1, 2, 3])
   221  
   222  
   223  if __name__ == '__main__':
   224    unittest.main()