github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/concat_source_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for ConcatSource."""
    19  # pytype: skip-file
    20  
    21  import logging
    22  import unittest
    23  
    24  import apache_beam as beam
    25  from apache_beam.io import iobase
    26  from apache_beam.io import range_trackers
    27  from apache_beam.io import source_test_utils
    28  from apache_beam.io.concat_source import ConcatSource
    29  from apache_beam.testing.test_pipeline import TestPipeline
    30  from apache_beam.testing.util import assert_that
    31  from apache_beam.testing.util import equal_to
    32  
# Only RangeSource is exported from this module.
# NOTE(review): presumably so other test modules can reuse it -- confirm.
__all__ = ['RangeSource']
    34  
    35  
    36  class RangeSource(iobase.BoundedSource):
    37  
    38    __hash__ = None  # type: ignore[assignment]
    39  
    40    def __init__(self, start, end, split_freq=1):
    41      assert start <= end
    42      self._start = start
    43      self._end = end
    44      self._split_freq = split_freq
    45  
    46    def _normalize(self, start_position, end_position):
    47      return (
    48          self._start if start_position is None else start_position,
    49          self._end if end_position is None else end_position)
    50  
    51    def _round_up(self, index):
    52      """Rounds up to the nearest mulitple of split_freq."""
    53      return index - index % -self._split_freq
    54  
    55    def estimate_size(self):
    56      return self._end - self._start
    57  
    58    def split(self, desired_bundle_size, start_position=None, end_position=None):
    59      start, end = self._normalize(start_position, end_position)
    60      for sub_start in range(start, end, desired_bundle_size):
    61        sub_end = min(self._end, sub_start + desired_bundle_size)
    62        yield iobase.SourceBundle(
    63            sub_end - sub_start,
    64            RangeSource(sub_start, sub_end, self._split_freq),
    65            sub_start,
    66            sub_end)
    67  
    68    def get_range_tracker(self, start_position, end_position):
    69      start, end = self._normalize(start_position, end_position)
    70      return range_trackers.OffsetRangeTracker(start, end)
    71  
    72    def read(self, range_tracker):
    73      for k in range(self._round_up(range_tracker.start_position()),
    74                     self._round_up(range_tracker.stop_position())):
    75        if k % self._split_freq == 0:
    76          if not range_tracker.try_claim(k):
    77            return
    78        yield k
    79  
    80    # For testing
    81    def __eq__(self, other):
    82      return (
    83          type(self) == type(other) and self._start == other._start and
    84          self._end == other._end and self._split_freq == other._split_freq)
    85  
    86  
    87  class ConcatSourceTest(unittest.TestCase):
    88    def test_range_source(self):
    89      source_test_utils.assert_split_at_fraction_exhaustive(RangeSource(0, 10, 3))
    90  
    91    def test_conact_source(self):
    92      source = ConcatSource([
    93          RangeSource(0, 4),
    94          RangeSource(4, 8),
    95          RangeSource(8, 12),
    96          RangeSource(12, 16),
    97      ])
    98      self.assertEqual(
    99          list(source.read(source.get_range_tracker())), list(range(16)))
   100      self.assertEqual(
   101          list(source.read(source.get_range_tracker((1, None), (2, 10)))),
   102          list(range(4, 10)))
   103      range_tracker = source.get_range_tracker(None, None)
   104      self.assertEqual(range_tracker.position_at_fraction(0), (0, 0))
   105      self.assertEqual(range_tracker.position_at_fraction(.5), (2, 8))
   106      self.assertEqual(range_tracker.position_at_fraction(.625), (2, 10))
   107  
   108      # Simulate a read.
   109      self.assertEqual(range_tracker.try_claim((0, None)), True)
   110      self.assertEqual(range_tracker.sub_range_tracker(0).try_claim(2), True)
   111      self.assertEqual(range_tracker.fraction_consumed(), 0.125)
   112  
   113      self.assertEqual(range_tracker.try_claim((1, None)), True)
   114      self.assertEqual(range_tracker.sub_range_tracker(1).try_claim(6), True)
   115      self.assertEqual(range_tracker.fraction_consumed(), 0.375)
   116      self.assertEqual(range_tracker.try_split((0, 1)), None)
   117      self.assertEqual(range_tracker.try_split((1, 5)), None)
   118  
   119      self.assertEqual(range_tracker.try_split((3, 14)), ((3, None), 0.75))
   120      self.assertEqual(range_tracker.try_claim((3, None)), False)
   121      self.assertEqual(range_tracker.sub_range_tracker(1).try_claim(7), True)
   122      self.assertEqual(range_tracker.try_claim((2, None)), True)
   123      self.assertEqual(range_tracker.sub_range_tracker(2).try_claim(9), True)
   124  
   125      self.assertEqual(range_tracker.try_split((2, 8)), None)
   126      self.assertEqual(range_tracker.try_split((2, 11)), ((2, 11), 11. / 12))
   127      self.assertEqual(range_tracker.sub_range_tracker(2).try_claim(10), True)
   128      self.assertEqual(range_tracker.sub_range_tracker(2).try_claim(11), False)
   129  
   130    def test_fraction_consumed_at_end(self):
   131      source = ConcatSource([
   132          RangeSource(0, 2),
   133          RangeSource(2, 4),
   134      ])
   135      range_tracker = source.get_range_tracker((2, None), None)
   136      self.assertEqual(range_tracker.fraction_consumed(), 1.0)
   137  
   138    def test_estimate_size(self):
   139      source = ConcatSource([
   140          RangeSource(0, 10),
   141          RangeSource(10, 100),
   142          RangeSource(100, 1000),
   143      ])
   144      self.assertEqual(source.estimate_size(), 1000)
   145  
   146    def test_position_at_fration(self):
   147      ranges = [(0, 4), (4, 16), (16, 24), (24, 32)]
   148      source = ConcatSource([
   149          iobase.SourceBundle((range[1] - range[0]) / 32.,
   150                              RangeSource(*range),
   151                              None,
   152                              None) for range in ranges
   153      ])
   154  
   155      range_tracker = source.get_range_tracker()
   156      self.assertEqual(range_tracker.position_at_fraction(0), (0, 0))
   157      self.assertEqual(range_tracker.position_at_fraction(.01), (0, 1))
   158      self.assertEqual(range_tracker.position_at_fraction(.1), (0, 4))
   159      self.assertEqual(range_tracker.position_at_fraction(.125), (1, 4))
   160      self.assertEqual(range_tracker.position_at_fraction(.2), (1, 7))
   161      self.assertEqual(range_tracker.position_at_fraction(.7), (2, 23))
   162      self.assertEqual(range_tracker.position_at_fraction(.75), (3, 24))
   163      self.assertEqual(range_tracker.position_at_fraction(.8), (3, 26))
   164      self.assertEqual(range_tracker.position_at_fraction(1), (4, None))
   165  
   166      range_tracker = source.get_range_tracker((1, None), (3, None))
   167      self.assertEqual(range_tracker.position_at_fraction(0), (1, 4))
   168      self.assertEqual(range_tracker.position_at_fraction(.01), (1, 5))
   169      self.assertEqual(range_tracker.position_at_fraction(.5), (1, 14))
   170      self.assertEqual(range_tracker.position_at_fraction(.599), (1, 16))
   171      self.assertEqual(range_tracker.position_at_fraction(.601), (2, 17))
   172      self.assertEqual(range_tracker.position_at_fraction(1), (3, None))
   173  
   174    def test_empty_source(self):
   175      read_all = source_test_utils.read_from_source
   176  
   177      empty = RangeSource(0, 0)
   178      self.assertEqual(read_all(ConcatSource([])), [])
   179      self.assertEqual(read_all(ConcatSource([empty])), [])
   180      self.assertEqual(read_all(ConcatSource([empty, empty])), [])
   181  
   182      range10 = RangeSource(0, 10)
   183      self.assertEqual(read_all(ConcatSource([range10]), (0, None), (0, 0)), [])
   184      self.assertEqual(read_all(ConcatSource([range10]), (0, 10), (1, None)), [])
   185      self.assertEqual(
   186          read_all(ConcatSource([range10, range10]), (0, 10), (1, 0)), [])
   187  
   188    def test_single_source(self):
   189      read_all = source_test_utils.read_from_source
   190  
   191      range10 = RangeSource(0, 10)
   192      self.assertEqual(read_all(ConcatSource([range10])), list(range(10)))
   193      self.assertEqual(
   194          read_all(ConcatSource([range10]), (0, 5)), list(range(5, 10)))
   195      self.assertEqual(
   196          read_all(ConcatSource([range10]), None, (0, 5)), list(range(5)))
   197  
   198    def test_source_with_empty_ranges(self):
   199      read_all = source_test_utils.read_from_source
   200  
   201      empty = RangeSource(0, 0)
   202      self.assertEqual(read_all(empty), [])
   203  
   204      range10 = RangeSource(0, 10)
   205      self.assertEqual(
   206          read_all(ConcatSource([empty, empty, range10])), list(range(10)))
   207      self.assertEqual(
   208          read_all(ConcatSource([empty, range10, empty])), list(range(10)))
   209      self.assertEqual(
   210          read_all(ConcatSource([range10, empty, range10, empty])),
   211          list(range(10)) + list(range(10)))
   212  
   213    def test_source_with_empty_ranges_exhastive(self):
   214      empty = RangeSource(0, 0)
   215      source = ConcatSource([
   216          empty,
   217          RangeSource(0, 10),
   218          empty,
   219          empty,
   220          RangeSource(10, 13),
   221          RangeSource(13, 17),
   222          empty,
   223      ])
   224      source_test_utils.assert_split_at_fraction_exhaustive(source)
   225  
   226    def test_run_concat_direct(self):
   227      source = ConcatSource([
   228          RangeSource(0, 10),
   229          RangeSource(10, 100),
   230          RangeSource(100, 1000),
   231      ])
   232      with TestPipeline() as pipeline:
   233        pcoll = pipeline | beam.io.Read(source)
   234        assert_that(pcoll, equal_to(list(range(1000))))
   235  
   236    def test_conact_source_exhaustive(self):
   237      source = ConcatSource([
   238          RangeSource(0, 10),
   239          RangeSource(100, 110),
   240          RangeSource(1000, 1010),
   241      ])
   242      source_test_utils.assert_split_at_fraction_exhaustive(source)
   243  
   244  
   245  if __name__ == '__main__':
   246    logging.getLogger().setLevel(logging.INFO)
   247    unittest.main()