github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/range_trackers_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for the range_trackers module."""
    19  # pytype: skip-file
    20  
    21  import copy
    22  import logging
    23  import math
    24  import unittest
    25  from typing import Optional
    26  from typing import Union
    27  
    28  from apache_beam.io import range_trackers
    29  
    30  
    31  class OffsetRangeTrackerTest(unittest.TestCase):
    32    def test_try_return_record_simple_sparse(self):
    33      tracker = range_trackers.OffsetRangeTracker(100, 200)
    34      self.assertTrue(tracker.try_claim(110))
    35      self.assertTrue(tracker.try_claim(140))
    36      self.assertTrue(tracker.try_claim(183))
    37      self.assertFalse(tracker.try_claim(210))
    38  
    39    def test_try_return_record_simple_dense(self):
    40      tracker = range_trackers.OffsetRangeTracker(3, 6)
    41      self.assertTrue(tracker.try_claim(3))
    42      self.assertTrue(tracker.try_claim(4))
    43      self.assertTrue(tracker.try_claim(5))
    44      self.assertFalse(tracker.try_claim(6))
    45  
    46    def test_try_claim_update_last_attempt(self):
    47      tracker = range_trackers.OffsetRangeTracker(1, 2)
    48      self.assertTrue(tracker.try_claim(1))
    49      self.assertEqual(1, tracker.last_attempted_record_start)
    50  
    51      self.assertFalse(tracker.try_claim(3))
    52      self.assertEqual(3, tracker.last_attempted_record_start)
    53  
    54      self.assertFalse(tracker.try_claim(6))
    55      self.assertEqual(6, tracker.last_attempted_record_start)
    56  
    57      with self.assertRaises(Exception):
    58        tracker.try_claim(6)
    59  
    60    def test_set_current_position(self):
    61      tracker = range_trackers.OffsetRangeTracker(0, 6)
    62      self.assertTrue(tracker.try_claim(2))
    63      # Cannot set current position before successful claimed pos.
    64      with self.assertRaises(Exception):
    65        tracker.set_current_position(1)
    66  
    67      self.assertFalse(tracker.try_claim(10))
    68      tracker.set_current_position(11)
    69      self.assertEqual(10, tracker.last_attempted_record_start)
    70      self.assertEqual(11, tracker.last_record_start)
    71  
    72    def test_try_return_record_continuous_until_split_point(self):
    73      tracker = range_trackers.OffsetRangeTracker(9, 18)
    74      # Return records with gaps of 2; every 3rd record is a split point.
    75      self.assertTrue(tracker.try_claim(10))
    76      tracker.set_current_position(12)
    77      tracker.set_current_position(14)
    78      self.assertTrue(tracker.try_claim(16))
    79      # Out of range, but not a split point...
    80      tracker.set_current_position(18)
    81      tracker.set_current_position(20)
    82      # Out of range AND a split point.
    83      self.assertFalse(tracker.try_claim(22))
    84  
    85    def test_split_at_offset_fails_if_unstarted(self):
    86      tracker = range_trackers.OffsetRangeTracker(100, 200)
    87      self.assertFalse(tracker.try_split(150))
    88  
    89    def test_split_at_offset(self):
    90      tracker = range_trackers.OffsetRangeTracker(100, 200)
    91      self.assertTrue(tracker.try_claim(110))
    92      # Example positions we shouldn't split at, when last record starts at 110:
    93      self.assertFalse(tracker.try_split(109))
    94      self.assertFalse(tracker.try_split(110))
    95      self.assertFalse(tracker.try_split(200))
    96      self.assertFalse(tracker.try_split(210))
    97      # Example positions we *should* split at:
    98      self.assertTrue(copy.copy(tracker).try_split(111))
    99      self.assertTrue(copy.copy(tracker).try_split(129))
   100      self.assertTrue(copy.copy(tracker).try_split(130))
   101      self.assertTrue(copy.copy(tracker).try_split(131))
   102      self.assertTrue(copy.copy(tracker).try_split(150))
   103      self.assertTrue(copy.copy(tracker).try_split(199))
   104  
   105      # If we split at 170 and then at 150:
   106      self.assertTrue(tracker.try_split(170))
   107      self.assertTrue(tracker.try_split(150))
   108      # Should be able  to return a record starting before the new stop offset.
   109      # Returning records starting at the same offset is ok.
   110      self.assertTrue(copy.copy(tracker).try_claim(135))
   111      self.assertTrue(copy.copy(tracker).try_claim(135))
   112      # Should be able to return a record starting right before the new stop
   113      # offset.
   114      self.assertTrue(copy.copy(tracker).try_claim(149))
   115      # Should not be able to return a record starting at or after the new stop
   116      # offset.
   117      self.assertFalse(tracker.try_claim(150))
   118      self.assertFalse(tracker.try_claim(151))
   119      # Should accept non-splitpoint records starting after stop offset.
   120      tracker.set_current_position(152)
   121      tracker.set_current_position(160)
   122      tracker.set_current_position(171)
   123  
   124    def test_get_position_for_fraction_dense(self):
   125      # Represents positions 3, 4, 5.
   126      tracker = range_trackers.OffsetRangeTracker(3, 6)
   127  
   128      # Position must be an integer type.
   129      self.assertTrue(isinstance(tracker.position_at_fraction(0.0), int))
   130      # [3, 3) represents 0.0 of [3, 6)
   131      self.assertEqual(3, tracker.position_at_fraction(0.0))
   132      # [3, 4) represents up to 1/3 of [3, 6)
   133      self.assertEqual(4, tracker.position_at_fraction(1.0 / 6))
   134      self.assertEqual(4, tracker.position_at_fraction(0.333))
   135      # [3, 5) represents up to 2/3 of [3, 6)
   136      self.assertEqual(5, tracker.position_at_fraction(0.334))
   137      self.assertEqual(5, tracker.position_at_fraction(0.666))
   138      # Any fraction consumed over 2/3 means the whole [3, 6) has been consumed.
   139      self.assertEqual(6, tracker.position_at_fraction(0.667))
   140  
   141    def test_get_fraction_consumed_dense(self):
   142      tracker = range_trackers.OffsetRangeTracker(3, 6)
   143      self.assertEqual(0, tracker.fraction_consumed())
   144      self.assertTrue(tracker.try_claim(3))
   145      self.assertEqual(0.0, tracker.fraction_consumed())
   146      self.assertTrue(tracker.try_claim(4))
   147      self.assertEqual(1.0 / 3, tracker.fraction_consumed())
   148      self.assertTrue(tracker.try_claim(5))
   149      self.assertEqual(2.0 / 3, tracker.fraction_consumed())
   150      tracker.set_current_position(6)
   151      self.assertEqual(1.0, tracker.fraction_consumed())
   152      tracker.set_current_position(7)
   153      self.assertFalse(tracker.try_claim(7))
   154  
   155    def test_get_fraction_consumed_sparse(self):
   156      tracker = range_trackers.OffsetRangeTracker(100, 200)
   157      self.assertEqual(0, tracker.fraction_consumed())
   158      self.assertTrue(tracker.try_claim(110))
   159      # Consumed positions through 110 = total 10 positions of 100 done.
   160      self.assertEqual(0.10, tracker.fraction_consumed())
   161      self.assertTrue(tracker.try_claim(150))
   162      self.assertEqual(0.50, tracker.fraction_consumed())
   163      self.assertTrue(tracker.try_claim(195))
   164      self.assertEqual(0.95, tracker.fraction_consumed())
   165  
   166    def test_everything_with_unbounded_range(self):
   167      tracker = range_trackers.OffsetRangeTracker(
   168          100, range_trackers.OffsetRangeTracker.OFFSET_INFINITY)
   169      self.assertTrue(tracker.try_claim(150))
   170      self.assertTrue(tracker.try_claim(250))
   171      # get_position_for_fraction_consumed should fail for an unbounded range
   172      with self.assertRaises(Exception):
   173        tracker.position_at_fraction(0.5)
   174  
   175    def test_try_return_first_record_not_split_point(self):
   176      with self.assertRaises(Exception):
   177        range_trackers.OffsetRangeTracker(100, 200).set_current_position(120)
   178  
   179    def test_try_return_record_non_monotonic(self):
   180      tracker = range_trackers.OffsetRangeTracker(100, 200)
   181      self.assertTrue(tracker.try_claim(120))
   182      with self.assertRaises(Exception):
   183        tracker.try_claim(110)
   184  
   185    def test_try_split_points(self):
   186      tracker = range_trackers.OffsetRangeTracker(100, 400)
   187  
   188      def dummy_callback(stop_position):
   189        return int(stop_position // 5)
   190  
   191      tracker.set_split_points_unclaimed_callback(dummy_callback)
   192  
   193      self.assertEqual(tracker.split_points(), (0, 81))
   194      self.assertTrue(tracker.try_claim(120))
   195      self.assertEqual(tracker.split_points(), (0, 81))
   196      self.assertTrue(tracker.try_claim(140))
   197      self.assertEqual(tracker.split_points(), (1, 81))
   198      tracker.try_split(200)
   199      self.assertEqual(tracker.split_points(), (1, 41))
   200      self.assertTrue(tracker.try_claim(150))
   201      self.assertEqual(tracker.split_points(), (2, 41))
   202      self.assertTrue(tracker.try_claim(180))
   203      self.assertEqual(tracker.split_points(), (3, 41))
   204      self.assertFalse(tracker.try_claim(210))
   205      self.assertEqual(tracker.split_points(), (3, 41))
   206  
   207  
   208  class OrderedPositionRangeTrackerTest(unittest.TestCase):
   209    class DoubleRangeTracker(range_trackers.OrderedPositionRangeTracker):
   210      @staticmethod
   211      def fraction_to_position(fraction, start, end):
   212        return start + (end - start) * fraction
   213  
   214      @staticmethod
   215      def position_to_fraction(pos, start, end):
   216        return float(pos - start) / (end - start)
   217  
   218    def test_try_claim(self):
   219      tracker = self.DoubleRangeTracker(10, 20)
   220      self.assertTrue(tracker.try_claim(10))
   221      self.assertTrue(tracker.try_claim(15))
   222      self.assertFalse(tracker.try_claim(20))
   223      self.assertFalse(tracker.try_claim(25))
   224  
   225    def test_fraction_consumed(self):
   226      tracker = self.DoubleRangeTracker(10, 20)
   227      self.assertEqual(0, tracker.fraction_consumed())
   228      tracker.try_claim(10)
   229      self.assertEqual(0, tracker.fraction_consumed())
   230      tracker.try_claim(15)
   231      self.assertEqual(.5, tracker.fraction_consumed())
   232      tracker.try_claim(17)
   233      self.assertEqual(.7, tracker.fraction_consumed())
   234      tracker.try_claim(25)
   235      self.assertEqual(.7, tracker.fraction_consumed())
   236  
   237    def test_try_split(self):
   238      tracker = self.DoubleRangeTracker(10, 20)
   239      tracker.try_claim(15)
   240      self.assertEqual(.5, tracker.fraction_consumed())
   241      # Split at 18.
   242      self.assertEqual((18, 0.8), tracker.try_split(18))
   243      # Fraction consumed reflects smaller range.
   244      self.assertEqual(.625, tracker.fraction_consumed())
   245      # We can claim anything less than 18,
   246      self.assertTrue(tracker.try_claim(17))
   247      # but can't split before claimed 17,
   248      self.assertIsNone(tracker.try_split(16))
   249      # nor claim anything at or after 18.
   250      self.assertFalse(tracker.try_claim(18))
   251      self.assertFalse(tracker.try_claim(19))
   252  
   253    def test_claim_order(self):
   254      tracker = self.DoubleRangeTracker(10, 20)
   255      tracker.try_claim(12)
   256      tracker.try_claim(15)
   257      with self.assertRaises(ValueError):
   258        tracker.try_claim(13)
   259  
   260    def test_out_of_range(self):
   261      tracker = self.DoubleRangeTracker(10, 20)
   262  
   263      # Can't claim before range.
   264      with self.assertRaises(ValueError):
   265        tracker.try_claim(-5)
   266  
   267      # Can't split before range.
   268      self.assertFalse(tracker.try_split(-5))
   269  
   270      # Reject useless split at start position.
   271      self.assertFalse(tracker.try_split(10))
   272  
   273      # Can't split after range.
   274      self.assertFalse(tracker.try_split(25))
   275      tracker.try_split(15)
   276  
   277      # Can't split after modified range.
   278      self.assertFalse(tracker.try_split(17))
   279  
   280      # Reject useless split at end position.
   281      self.assertFalse(tracker.try_split(15))
   282      self.assertTrue(tracker.try_split(14))
   283  
   284  
   285  class UnsplittableRangeTrackerTest(unittest.TestCase):
   286    def test_try_claim(self):
   287      tracker = range_trackers.UnsplittableRangeTracker(
   288          range_trackers.OffsetRangeTracker(100, 200))
   289      self.assertTrue(tracker.try_claim(110))
   290      self.assertTrue(tracker.try_claim(140))
   291      self.assertTrue(tracker.try_claim(183))
   292      self.assertFalse(tracker.try_claim(210))
   293  
   294    def test_try_split_fails(self):
   295      tracker = range_trackers.UnsplittableRangeTracker(
   296          range_trackers.OffsetRangeTracker(100, 200))
   297      self.assertTrue(tracker.try_claim(110))
   298      # Out of range
   299      self.assertFalse(tracker.try_split(109))
   300      self.assertFalse(tracker.try_split(210))
   301  
   302      # Within range. But splitting is still unsuccessful.
   303      self.assertFalse(copy.copy(tracker).try_split(111))
   304      self.assertFalse(copy.copy(tracker).try_split(130))
   305      self.assertFalse(copy.copy(tracker).try_split(199))
   306  
   307  
class LexicographicKeyRangeTrackerTest(unittest.TestCase):
  """Tests of LexicographicKeyRangeTracker."""

  # Aliases for the tracker's two conversion routines; _check calls them as
  # self.key_to_fraction(...) and self.fraction_to_key(...).
  key_to_fraction = (
      range_trackers.LexicographicKeyRangeTracker.position_to_fraction)
  fraction_to_key = (
      range_trackers.LexicographicKeyRangeTracker.fraction_to_position)

  def _check(
      self,
      fraction: Optional[float] = None,
      key: Optional[Union[bytes, str]] = None,
      start: Optional[Union[bytes, str]] = None,
      end: Optional[Union[bytes, str]] = None,
      delta: float = 0.0) -> None:
    """Asserts that key <-> fraction conversion round-trips for one case.

    At least one of `fraction` and `key` must be given. Whichever is missing
    is first derived from the other; both conversions are then recomputed and
    compared against the (given or derived) values.

    Args:
      fraction: expected fraction of `key` between `start` and `end`.
      key: expected key at `fraction` between `start` and `end`.
      start: start key handed to both conversions; may be None.
      end: end key handed to both conversions; may be None.
      delta: if non-zero, the fraction is compared with this absolute
        tolerance instead of exact equality. The key is always compared
        exactly.
    """
    assert key is not None or fraction is not None
    if fraction is None:
      fraction = self.key_to_fraction(key, start, end)
    elif key is None:
      key = self.fraction_to_key(fraction, start, end)

    if key is None and end is None and fraction == 1:
      # No way to distinguish from fraction == 0.
      computed_fraction = 1
    else:
      computed_fraction = self.key_to_fraction(key, start, end)
    computed_key = self.fraction_to_key(fraction, start, end)

    # str(locals()) puts every local of this call into the failure message,
    # so the message text depends on the local variable names used here.
    if delta:
      self.assertAlmostEqual(
          computed_fraction,
          fraction,
          delta=delta,
          places=None,
          msg=str(locals()))
    else:
      self.assertEqual(computed_fraction, fraction, str(locals()))
    self.assertEqual(computed_key, key, str(locals()))

  def test_key_to_fraction_no_endpoints(self) -> None:
    """Round-trips keys and fractions when neither start nor end is given."""
    self._check(key=b'\x07', fraction=7 / 256.)
    self._check(key=b'\xFF', fraction=255 / 256.)
    self._check(key=b'\x01\x02\x03', fraction=(2**16 + 2**9 + 3) / (2.0**24))
    self._check(key=b'UUUUUUT', fraction=1 / 3)
    self._check(key=b'3333334', fraction=1 / 5)
    self._check(key=b'$\x92I$\x92I$', fraction=1 / 7, delta=1e-3)
    # NOTE(review): this repeats the b'\x01\x02\x03' check made above.
    self._check(key=b'\x01\x02\x03', fraction=(2**16 + 2**9 + 3) / (2.0**24))

  def test_key_to_fraction(self) -> None:
    """Round-trips bytes and str keys with various start/end combinations."""
    # test no key, no start
    self._check(end=b'eeeeee', fraction=0.0)
    self._check(end='eeeeee', fraction=0.0)

    # test no fraction
    self._check(key=b'bbbbbb', start=b'aaaaaa', end=b'eeeeee')
    self._check(key='bbbbbb', start='aaaaaa', end='eeeeee')

    # test no start
    self._check(key=b'eeeeee', end=b'eeeeee', fraction=1.0)
    self._check(key='eeeeee', end='eeeeee', fraction=1.0)
    self._check(key=b'\x19YYYYY@', end=b'eeeeee', fraction=0.25)
    # NOTE(review): bytes key combined with a str end -- confirm the mixed
    # types are intentional.
    self._check(key=b'2\xb2\xb2\xb2\xb2\xb2\x80', end='eeeeee', fraction=0.5)
    self._check(key=b'L\x0c\x0c\x0c\x0c\x0b\xc0', end=b'eeeeee', fraction=0.75)

    # test bytes keys
    self._check(key=b'\x87', start=b'\x80', fraction=7 / 128.)
    self._check(key=b'\x07', end=b'\x10', fraction=7 / 16.)
    self._check(key=b'\x47', start=b'\x40', end=b'\x80', fraction=7 / 64.)
    self._check(key=b'\x47\x80', start=b'\x40', end=b'\x80', fraction=15 / 128.)

    # test string keys
    self._check(key='aaaaaa', start='aaaaaa', end='eeeeee', fraction=0.0)
    self._check(key='bbbbbb', start='aaaaaa', end='eeeeee', fraction=0.25)
    self._check(key='cccccc', start='aaaaaa', end='eeeeee', fraction=0.5)
    self._check(key='dddddd', start='aaaaaa', end='eeeeee', fraction=0.75)
    self._check(key='eeeeee', start='aaaaaa', end='eeeeee', fraction=1.0)

  def test_key_to_fraction_common_prefix(self) -> None:
    """Round-trips keys whose start and end share a long common prefix."""
    # test bytes keys
    self._check(
        key=b'a' * 100 + b'b',
        start=b'a' * 100 + b'a',
        end=b'a' * 100 + b'c',
        fraction=0.5)
    self._check(
        key=b'a' * 100 + b'b',
        start=b'a' * 100 + b'a',
        end=b'a' * 100 + b'e',
        fraction=0.25)
    self._check(
        key=b'\xFF' * 100 + b'\x40',
        start=b'\xFF' * 100,
        end=None,
        fraction=0.25)
    self._check(
        key=b'foob',
        start=b'fooa\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFE',
        end=b'foob\x00\x00\x00\x00\x00\x00\x00\x00\x02',
        fraction=0.5)

    # test string keys
    self._check(
        key='a' * 100 + 'a',
        start='a' * 100 + 'a',
        end='a' * 100 + 'e',
        fraction=0.0)
    self._check(
        key='a' * 100 + 'b',
        start='a' * 100 + 'a',
        end='a' * 100 + 'e',
        fraction=0.25)
    self._check(
        key='a' * 100 + 'c',
        start='a' * 100 + 'a',
        end='a' * 100 + 'e',
        fraction=0.5)
    self._check(
        key='a' * 100 + 'd',
        start='a' * 100 + 'a',
        end='a' * 100 + 'e',
        fraction=0.75)
    self._check(
        key='a' * 100 + 'e',
        start='a' * 100 + 'a',
        end='a' * 100 + 'e',
        fraction=1.0)

  def test_tiny(self) -> None:
    """Round-trips very small fractions (down to .5**100)."""
    # test bytes keys
    self._check(fraction=.5**20, key=b'\0\0\x10')
    self._check(fraction=.5**20, start=b'a', end=b'b', key=b'a\0\0\x10')
    self._check(fraction=.5**20, start=b'a', end=b'c', key=b'a\0\0\x20')
    self._check(
        fraction=.5**20, start=b'xy_a', end=b'xy_c', key=b'xy_a\0\0\x20')
    self._check(
        fraction=.5**20, start=b'\xFF\xFF\x80', key=b'\xFF\xFF\x80\x00\x08')
    self._check(
        fraction=.5**20 / 3,
        start=b'xy_a',
        end=b'xy_c',
        key=b'xy_a\x00\x00\n\xaa\xaa\xaa\xaa\xaa',
        delta=1e-15)
    self._check(fraction=.5**100, key=b'\0' * 12 + b'\x10')

    # test string keys
    self._check(fraction=.5**20, start='a', end='b', key='a\0\0\x10')
    self._check(fraction=.5**20, start='a', end='c', key='a\0\0\x20')
    self._check(fraction=.5**20, start='xy_a', end='xy_c', key='xy_a\0\0\x20')

  def test_lots(self) -> None:
    """Round-trips many fractions against a fixed set of start/end ranges.

    The first group of fractions is compared exactly; the second group is
    compared with an absolute tolerance of 1e-14.
    """
    # NOTE(review): b'0x75' below is the four ASCII bytes '0', 'x', '7', '5',
    # not the single byte b'\x75' -- confirm that this is intentional.
    for fraction in (0, 1, .5, .75, 7. / 512, 1 - 7. / 4096):
      self._check(fraction)
      self._check(fraction, start=b'\x01')
      self._check(fraction, end=b'\xF0')
      self._check(fraction, start=b'0x75', end=b'\x76')
      self._check(fraction, start=b'0x75', end=b'\x77')
      self._check(fraction, start=b'0x75', end=b'\x78')
      self._check(
          fraction, start=b'a' * 100 + b'\x80', end=b'a' * 100 + b'\x81')
      self._check(
          fraction, start=b'a' * 101 + b'\x80', end=b'a' * 101 + b'\x81')
      self._check(
          fraction, start=b'a' * 102 + b'\x80', end=b'a' * 102 + b'\x81')
    for fraction in (.3, 1 / 3., 1 / math.e, .001, 1e-30, .99, .999999):
      self._check(fraction, delta=1e-14)
      self._check(fraction, start=b'\x01', delta=1e-14)
      self._check(fraction, end=b'\xF0', delta=1e-14)
      self._check(fraction, start=b'0x75', end=b'\x76', delta=1e-14)
      self._check(fraction, start=b'0x75', end=b'\x77', delta=1e-14)
      self._check(fraction, start=b'0x75', end=b'\x78', delta=1e-14)
      self._check(
          fraction,
          start=b'a' * 100 + b'\x80',
          end=b'a' * 100 + b'\x81',
          delta=1e-14)

  def test_good_prec(self) -> None:
    """Pins the exact keys produced for irrational fractions."""
    # There should be about 7 characters (~53 bits) of precision
    # (beyond the common prefix of start and end).
    self._check(
        1 / math.e,
        start='AAAAAAA',
        end='zzzzzzz',
        key='VNg/ot\x82',
        delta=1e-14)
    self._check(
        1 / math.e,
        start=b'abc_abc',
        end=b'abc_xyz',
        key=b'abc_i\xe0\xf4\x84\x86\x99\x96',
        delta=1e-15)
    # This remains true even if the start and end keys are given to
    # high precision.
    self._check(
        1 / math.e,
        start=b'abcd_abc\0\0\0\0\0_______________abc',
        end=b'abcd_xyz\0\0\0\0\0\0_______________abc',
        key=b'abcd_i\xe0\xf4\x84\x86\x99\x96',
        delta=1e-15)
    # For very small fractions, however, higher precision is used to
    # accurately represent small increments in the keyspace.
    self._check(
        1e-20 / math.e,
        start=b'abcd_abc',
        end=b'abcd_xyz',
        key=b'abcd_abc\x00\x00\x00\x00\x00\x01\x91#\x172N\xbb',
        delta=1e-35)
   515  
   516  
   517  if __name__ == '__main__':
   518    logging.getLogger().setLevel(logging.INFO)
   519    unittest.main()