#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the range_trackers module."""
# pytype: skip-file

import copy
import logging
import math
import unittest
from typing import Optional
from typing import Union

from apache_beam.io import range_trackers


class OffsetRangeTrackerTest(unittest.TestCase):
  """Tests of range_trackers.OffsetRangeTracker."""
  def test_try_return_record_simple_sparse(self):
    """Sparse claims inside (100, 200) succeed; a claim at 210 fails."""
    tracker = range_trackers.OffsetRangeTracker(100, 200)
    self.assertTrue(tracker.try_claim(110))
    self.assertTrue(tracker.try_claim(140))
    self.assertTrue(tracker.try_claim(183))
    self.assertFalse(tracker.try_claim(210))

  def test_try_return_record_simple_dense(self):
    """Claims at 3, 4 and 5 succeed for (3, 6); the claim at 6 fails."""
    tracker = range_trackers.OffsetRangeTracker(3, 6)
    self.assertTrue(tracker.try_claim(3))
    self.assertTrue(tracker.try_claim(4))
    self.assertTrue(tracker.try_claim(5))
    self.assertFalse(tracker.try_claim(6))

  def test_try_claim_update_last_attempt(self):
    """last_attempted_record_start follows every claim, even failed ones."""
    tracker = range_trackers.OffsetRangeTracker(1, 2)
    self.assertTrue(tracker.try_claim(1))
    self.assertEqual(1, tracker.last_attempted_record_start)

    self.assertFalse(tracker.try_claim(3))
    self.assertEqual(3, tracker.last_attempted_record_start)

    self.assertFalse(tracker.try_claim(6))
    self.assertEqual(6, tracker.last_attempted_record_start)

    # Attempting the same offset a second time is expected to raise.
    with self.assertRaises(Exception):
      tracker.try_claim(6)

  def test_set_current_position(self):
    """set_current_position rejects going backwards and updates the start."""
    tracker = range_trackers.OffsetRangeTracker(0, 6)
    self.assertTrue(tracker.try_claim(2))
    # Cannot set current position before successful claimed pos.
    with self.assertRaises(Exception):
      tracker.set_current_position(1)

    self.assertFalse(tracker.try_claim(10))
    tracker.set_current_position(11)
    self.assertEqual(10, tracker.last_attempted_record_start)
    self.assertEqual(11, tracker.last_record_start)

  def test_try_return_record_continuous_until_split_point(self):
    """Non-split-point positions may pass the end; a claim there fails."""
    tracker = range_trackers.OffsetRangeTracker(9, 18)
    # Return records with gaps of 2; every 3rd record is a split point.
    self.assertTrue(tracker.try_claim(10))
    tracker.set_current_position(12)
    tracker.set_current_position(14)
    self.assertTrue(tracker.try_claim(16))
    # Out of range, but not a split point...
    tracker.set_current_position(18)
    tracker.set_current_position(20)
    # Out of range AND a split point.
    self.assertFalse(tracker.try_claim(22))

  def test_split_at_offset_fails_if_unstarted(self):
    """try_split is falsy when nothing has been claimed yet."""
    tracker = range_trackers.OffsetRangeTracker(100, 200)
    self.assertFalse(tracker.try_split(150))

  def test_split_at_offset(self):
    """Checks which offsets can be split at, and claims after splitting."""
    tracker = range_trackers.OffsetRangeTracker(100, 200)
    self.assertTrue(tracker.try_claim(110))
    # Example positions we shouldn't split at, when last record starts at 110:
    self.assertFalse(tracker.try_split(109))
    self.assertFalse(tracker.try_split(110))
    self.assertFalse(tracker.try_split(200))
    self.assertFalse(tracker.try_split(210))
    # Example positions we *should* split at:
    # (each split runs on a copy so the original tracker keeps its range)
    self.assertTrue(copy.copy(tracker).try_split(111))
    self.assertTrue(copy.copy(tracker).try_split(129))
    self.assertTrue(copy.copy(tracker).try_split(130))
    self.assertTrue(copy.copy(tracker).try_split(131))
    self.assertTrue(copy.copy(tracker).try_split(150))
    self.assertTrue(copy.copy(tracker).try_split(199))

    # If we split at 170 and then at 150:
    self.assertTrue(tracker.try_split(170))
    self.assertTrue(tracker.try_split(150))
    # Should be able to return a record starting before the new stop offset.
    # Returning records starting at the same offset is ok.
    self.assertTrue(copy.copy(tracker).try_claim(135))
    self.assertTrue(copy.copy(tracker).try_claim(135))
    # Should be able to return a record starting right before the new stop
    # offset.
    self.assertTrue(copy.copy(tracker).try_claim(149))
    # Should not be able to return a record starting at or after the new stop
    # offset.
    self.assertFalse(tracker.try_claim(150))
    self.assertFalse(tracker.try_claim(151))
    # Should accept non-splitpoint records starting after stop offset.
    tracker.set_current_position(152)
    tracker.set_current_position(160)
    tracker.set_current_position(171)

  def test_get_position_for_fraction_dense(self):
    """position_at_fraction maps fractions of (3, 6) to integer positions."""
    # Represents positions 3, 4, 5.
    tracker = range_trackers.OffsetRangeTracker(3, 6)

    # Position must be an integer type.
    self.assertIsInstance(tracker.position_at_fraction(0.0), int)
    # [3, 3) represents 0.0 of [3, 6)
    self.assertEqual(3, tracker.position_at_fraction(0.0))
    # [3, 4) represents up to 1/3 of [3, 6)
    self.assertEqual(4, tracker.position_at_fraction(1.0 / 6))
    self.assertEqual(4, tracker.position_at_fraction(0.333))
    # [3, 5) represents up to 2/3 of [3, 6)
    self.assertEqual(5, tracker.position_at_fraction(0.334))
    self.assertEqual(5, tracker.position_at_fraction(0.666))
    # Any fraction consumed over 2/3 means the whole [3, 6) has been consumed.
    self.assertEqual(6, tracker.position_at_fraction(0.667))

  def test_get_fraction_consumed_dense(self):
    """fraction_consumed advances by 1/3 per claimed position of (3, 6)."""
    tracker = range_trackers.OffsetRangeTracker(3, 6)
    self.assertEqual(0, tracker.fraction_consumed())
    self.assertTrue(tracker.try_claim(3))
    self.assertEqual(0.0, tracker.fraction_consumed())
    self.assertTrue(tracker.try_claim(4))
    self.assertEqual(1.0 / 3, tracker.fraction_consumed())
    self.assertTrue(tracker.try_claim(5))
    self.assertEqual(2.0 / 3, tracker.fraction_consumed())
    tracker.set_current_position(6)
    self.assertEqual(1.0, tracker.fraction_consumed())
    tracker.set_current_position(7)
    self.assertFalse(tracker.try_claim(7))

  def test_get_fraction_consumed_sparse(self):
    """fraction_consumed tracks the last claimed offset within (100, 200)."""
    tracker = range_trackers.OffsetRangeTracker(100, 200)
    self.assertEqual(0, tracker.fraction_consumed())
    self.assertTrue(tracker.try_claim(110))
    # Consumed positions through 110 = total 10 positions of 100 done.
    self.assertEqual(0.10, tracker.fraction_consumed())
    self.assertTrue(tracker.try_claim(150))
    self.assertEqual(0.50, tracker.fraction_consumed())
    self.assertTrue(tracker.try_claim(195))
    self.assertEqual(0.95, tracker.fraction_consumed())

  def test_everything_with_unbounded_range(self):
    """An OFFSET_INFINITY end accepts claims; position_at_fraction raises."""
    tracker = range_trackers.OffsetRangeTracker(
        100, range_trackers.OffsetRangeTracker.OFFSET_INFINITY)
    self.assertTrue(tracker.try_claim(150))
    self.assertTrue(tracker.try_claim(250))
    # get_position_for_fraction_consumed should fail for an unbounded range
    with self.assertRaises(Exception):
      tracker.position_at_fraction(0.5)

  def test_try_return_first_record_not_split_point(self):
    """set_current_position before any successful claim raises."""
    with self.assertRaises(Exception):
      range_trackers.OffsetRangeTracker(100, 200).set_current_position(120)

  def test_try_return_record_non_monotonic(self):
    """Claiming an offset lower than a previously claimed one raises."""
    tracker = range_trackers.OffsetRangeTracker(100, 200)
    self.assertTrue(tracker.try_claim(120))
    with self.assertRaises(Exception):
      tracker.try_claim(110)

  def test_try_split_points(self):
    """Checks the split_points() tuple as claims and a split take place.

    The second tuple element is driven by a callback registered through
    set_split_points_unclaimed_callback.
    """
    tracker = range_trackers.OffsetRangeTracker(100, 400)

    def dummy_callback(stop_position):
      return int(stop_position // 5)

    tracker.set_split_points_unclaimed_callback(dummy_callback)

    self.assertEqual(tracker.split_points(), (0, 81))
    self.assertTrue(tracker.try_claim(120))
    self.assertEqual(tracker.split_points(), (0, 81))
    self.assertTrue(tracker.try_claim(140))
    self.assertEqual(tracker.split_points(), (1, 81))
    tracker.try_split(200)
    self.assertEqual(tracker.split_points(), (1, 41))
    self.assertTrue(tracker.try_claim(150))
    self.assertEqual(tracker.split_points(), (2, 41))
    self.assertTrue(tracker.try_claim(180))
    self.assertEqual(tracker.split_points(), (3, 41))
    self.assertFalse(tracker.try_claim(210))
    self.assertEqual(tracker.split_points(), (3, 41))


class OrderedPositionRangeTrackerTest(unittest.TestCase):
  """Tests of OrderedPositionRangeTracker through a numeric subclass."""
  class DoubleRangeTracker(range_trackers.OrderedPositionRangeTracker):
    """Subclass that interpolates numeric positions linearly."""
    @staticmethod
    def fraction_to_position(fraction, start, end):
      """Returns the position `fraction` of the way from start to end."""
      return start + (end - start) * fraction

    @staticmethod
    def position_to_fraction(pos, start, end):
      """Returns how far `pos` lies between start and end, as a float."""
      return float(pos - start) / (end - start)

  def test_try_claim(self):
    """Claims at 10 and 15 succeed for (10, 20); 20 and 25 fail."""
    tracker = self.DoubleRangeTracker(10, 20)
    self.assertTrue(tracker.try_claim(10))
    self.assertTrue(tracker.try_claim(15))
    self.assertFalse(tracker.try_claim(20))
    self.assertFalse(tracker.try_claim(25))

  def test_fraction_consumed(self):
    """fraction_consumed follows claims and is unchanged by a claim at 25."""
    tracker = self.DoubleRangeTracker(10, 20)
    self.assertEqual(0, tracker.fraction_consumed())
    tracker.try_claim(10)
    self.assertEqual(0, tracker.fraction_consumed())
    tracker.try_claim(15)
    self.assertEqual(.5, tracker.fraction_consumed())
    tracker.try_claim(17)
    self.assertEqual(.7, tracker.fraction_consumed())
    tracker.try_claim(25)
    self.assertEqual(.7, tracker.fraction_consumed())

  def test_try_split(self):
    """A split at 18 returns (18, 0.8) and limits later claims and splits."""
    tracker = self.DoubleRangeTracker(10, 20)
    tracker.try_claim(15)
    self.assertEqual(.5, tracker.fraction_consumed())
    # Split at 18.
    self.assertEqual((18, 0.8), tracker.try_split(18))
    # Fraction consumed reflects smaller range.
    self.assertEqual(.625, tracker.fraction_consumed())
    # We can claim anything less than 18,
    self.assertTrue(tracker.try_claim(17))
    # but can't split before claimed 17,
    self.assertIsNone(tracker.try_split(16))
    # nor claim anything at or after 18.
    self.assertFalse(tracker.try_claim(18))
    self.assertFalse(tracker.try_claim(19))

  def test_claim_order(self):
    """Claiming 13 after 15 raises ValueError."""
    tracker = self.DoubleRangeTracker(10, 20)
    tracker.try_claim(12)
    tracker.try_claim(15)
    with self.assertRaises(ValueError):
      tracker.try_claim(13)

  def test_out_of_range(self):
    """Claims and splits outside the current range are rejected."""
    tracker = self.DoubleRangeTracker(10, 20)

    # Can't claim before range.
    with self.assertRaises(ValueError):
      tracker.try_claim(-5)

    # Can't split before range.
    self.assertFalse(tracker.try_split(-5))

    # Reject useless split at start position.
    self.assertFalse(tracker.try_split(10))

    # Can't split after range.
    self.assertFalse(tracker.try_split(25))
    tracker.try_split(15)

    # Can't split after modified range.
    self.assertFalse(tracker.try_split(17))

    # Reject useless split at end position.
    self.assertFalse(tracker.try_split(15))
    self.assertTrue(tracker.try_split(14))


class UnsplittableRangeTrackerTest(unittest.TestCase):
  """Tests of UnsplittableRangeTracker wrapping an OffsetRangeTracker."""
  def test_try_claim(self):
    """Claims behave as in the OffsetRangeTracker sparse-claim test."""
    tracker = range_trackers.UnsplittableRangeTracker(
        range_trackers.OffsetRangeTracker(100, 200))
    self.assertTrue(tracker.try_claim(110))
    self.assertTrue(tracker.try_claim(140))
    self.assertTrue(tracker.try_claim(183))
    self.assertFalse(tracker.try_claim(210))

  def test_try_split_fails(self):
    """try_split is falsy both outside and inside the range."""
    tracker = range_trackers.UnsplittableRangeTracker(
        range_trackers.OffsetRangeTracker(100, 200))
    self.assertTrue(tracker.try_claim(110))
    # Out of range
    self.assertFalse(tracker.try_split(109))
    self.assertFalse(tracker.try_split(210))

    # Within range. But splitting is still unsuccessful.
    self.assertFalse(copy.copy(tracker).try_split(111))
    self.assertFalse(copy.copy(tracker).try_split(130))
    self.assertFalse(copy.copy(tracker).try_split(199))


class LexicographicKeyRangeTrackerTest(unittest.TestCase):
  """Tests of LexicographicKeyRangeTracker."""

  # Short aliases for the two conversions under test.
  key_to_fraction = (
      range_trackers.LexicographicKeyRangeTracker.position_to_fraction)
  fraction_to_key = (
      range_trackers.LexicographicKeyRangeTracker.fraction_to_position)

  def _check(
      self,
      fraction: Optional[float] = None,
      key: Optional[Union[bytes, str]] = None,
      start: Optional[Union[bytes, str]] = None,
      end: Optional[Union[bytes, str]] = None,
      delta: float = 0.0) -> None:
    """Asserts that a key and a fraction convert into each other.

    At least one of `fraction` and `key` must be given; the missing one is
    derived from the other with the conversion under test.

    Args:
      fraction: expected fraction of the way from `start` to `end`.
      key: expected key at that fraction.
      start: start of the key range, passed through to the conversions.
      end: end of the key range, passed through to the conversions.
      delta: if non-zero, the fraction is compared with assertAlmostEqual
        using this tolerance; otherwise exact equality is required. The key
        is always compared exactly.
    """
    assert key is not None or fraction is not None
    if fraction is None:
      fraction = self.key_to_fraction(key, start, end)
    elif key is None:
      key = self.fraction_to_key(fraction, start, end)

    if key is None and end is None and fraction == 1:
      # No way to distinguish from fraction == 0.
      computed_fraction = 1
    else:
      computed_fraction = self.key_to_fraction(key, start, end)
    computed_key = self.fraction_to_key(fraction, start, end)

    # str(locals()) puts every input and computed value in the failure
    # message.
    if delta:
      self.assertAlmostEqual(
          computed_fraction,
          fraction,
          delta=delta,
          places=None,
          msg=str(locals()))
    else:
      self.assertEqual(computed_fraction, fraction, str(locals()))
    self.assertEqual(computed_key, key, str(locals()))

  def test_key_to_fraction_no_endpoints(self):
    """Round-trips keys and fractions with neither start nor end given."""
    self._check(key=b'\x07', fraction=7 / 256.)
    self._check(key=b'\xFF', fraction=255 / 256.)
    self._check(key=b'\x01\x02\x03', fraction=(2**16 + 2**9 + 3) / (2.0**24))
    self._check(key=b'UUUUUUT', fraction=1 / 3)
    self._check(key=b'3333334', fraction=1 / 5)
    self._check(key=b'$\x92I$\x92I$', fraction=1 / 7, delta=1e-3)

  def test_key_to_fraction(self):
    """Round-trips bytes and str keys with various start/end combinations."""
    # test no key, no start
    self._check(end=b'eeeeee', fraction=0.0)
    self._check(end='eeeeee', fraction=0.0)

    # test no fraction
    self._check(key=b'bbbbbb', start=b'aaaaaa', end=b'eeeeee')
    self._check(key='bbbbbb', start='aaaaaa', end='eeeeee')

    # test no start
    self._check(key=b'eeeeee', end=b'eeeeee', fraction=1.0)
    self._check(key='eeeeee', end='eeeeee', fraction=1.0)
    self._check(key=b'\x19YYYYY@', end=b'eeeeee', fraction=0.25)
    self._check(key=b'2\xb2\xb2\xb2\xb2\xb2\x80', end='eeeeee', fraction=0.5)
    self._check(key=b'L\x0c\x0c\x0c\x0c\x0b\xc0', end=b'eeeeee', fraction=0.75)

    # test bytes keys
    self._check(key=b'\x87', start=b'\x80', fraction=7 / 128.)
    self._check(key=b'\x07', end=b'\x10', fraction=7 / 16.)
    self._check(key=b'\x47', start=b'\x40', end=b'\x80', fraction=7 / 64.)
    self._check(key=b'\x47\x80', start=b'\x40', end=b'\x80', fraction=15 / 128.)

    # test string keys
    self._check(key='aaaaaa', start='aaaaaa', end='eeeeee', fraction=0.0)
    self._check(key='bbbbbb', start='aaaaaa', end='eeeeee', fraction=0.25)
    self._check(key='cccccc', start='aaaaaa', end='eeeeee', fraction=0.5)
    self._check(key='dddddd', start='aaaaaa', end='eeeeee', fraction=0.75)
    self._check(key='eeeeee', start='aaaaaa', end='eeeeee', fraction=1.0)

  def test_key_to_fraction_common_prefix(self):
    """Round-trips keys whose start and end share a long common prefix."""
    # test bytes keys
    self._check(
        key=b'a' * 100 + b'b',
        start=b'a' * 100 + b'a',
        end=b'a' * 100 + b'c',
        fraction=0.5)
    self._check(
        key=b'a' * 100 + b'b',
        start=b'a' * 100 + b'a',
        end=b'a' * 100 + b'e',
        fraction=0.25)
    self._check(
        key=b'\xFF' * 100 + b'\x40',
        start=b'\xFF' * 100,
        end=None,
        fraction=0.25)
    self._check(
        key=b'foob',
        start=b'fooa\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFE',
        end=b'foob\x00\x00\x00\x00\x00\x00\x00\x00\x02',
        fraction=0.5)

    # test string keys
    self._check(
        key='a' * 100 + 'a',
        start='a' * 100 + 'a',
        end='a' * 100 + 'e',
        fraction=0.0)
    self._check(
        key='a' * 100 + 'b',
        start='a' * 100 + 'a',
        end='a' * 100 + 'e',
        fraction=0.25)
    self._check(
        key='a' * 100 + 'c',
        start='a' * 100 + 'a',
        end='a' * 100 + 'e',
        fraction=0.5)
    self._check(
        key='a' * 100 + 'd',
        start='a' * 100 + 'a',
        end='a' * 100 + 'e',
        fraction=0.75)
    self._check(
        key='a' * 100 + 'e',
        start='a' * 100 + 'a',
        end='a' * 100 + 'e',
        fraction=1.0)

  def test_tiny(self):
    """Round-trips very small fractions such as .5**20 and .5**100."""
    # test bytes keys
    self._check(fraction=.5**20, key=b'\0\0\x10')
    self._check(fraction=.5**20, start=b'a', end=b'b', key=b'a\0\0\x10')
    self._check(fraction=.5**20, start=b'a', end=b'c', key=b'a\0\0\x20')
    self._check(
        fraction=.5**20, start=b'xy_a', end=b'xy_c', key=b'xy_a\0\0\x20')
    self._check(
        fraction=.5**20, start=b'\xFF\xFF\x80', key=b'\xFF\xFF\x80\x00\x08')
    self._check(
        fraction=.5**20 / 3,
        start=b'xy_a',
        end=b'xy_c',
        key=b'xy_a\x00\x00\n\xaa\xaa\xaa\xaa\xaa',
        delta=1e-15)
    self._check(fraction=.5**100, key=b'\0' * 12 + b'\x10')

    # test string keys
    self._check(fraction=.5**20, start='a', end='b', key='a\0\0\x10')
    self._check(fraction=.5**20, start='a', end='c', key='a\0\0\x20')
    self._check(fraction=.5**20, start='xy_a', end='xy_c', key='xy_a\0\0\x20')

  def test_lots(self):
    """Round-trips many fractions over several ranges.

    The first group of fractions must round-trip exactly; the second group
    is checked with a tolerance of 1e-14.
    """
    # NOTE(review): b'0x75' below is the four ASCII bytes "0x75", not the
    # single byte 0x75 -- confirm whether b'\x75' was intended.
    for fraction in (0, 1, .5, .75, 7. / 512, 1 - 7. / 4096):
      self._check(fraction)
      self._check(fraction, start=b'\x01')
      self._check(fraction, end=b'\xF0')
      self._check(fraction, start=b'0x75', end=b'\x76')
      self._check(fraction, start=b'0x75', end=b'\x77')
      self._check(fraction, start=b'0x75', end=b'\x78')
      self._check(
          fraction, start=b'a' * 100 + b'\x80', end=b'a' * 100 + b'\x81')
      self._check(
          fraction, start=b'a' * 101 + b'\x80', end=b'a' * 101 + b'\x81')
      self._check(
          fraction, start=b'a' * 102 + b'\x80', end=b'a' * 102 + b'\x81')
    for fraction in (.3, 1 / 3., 1 / math.e, .001, 1e-30, .99, .999999):
      self._check(fraction, delta=1e-14)
      self._check(fraction, start=b'\x01', delta=1e-14)
      self._check(fraction, end=b'\xF0', delta=1e-14)
      self._check(fraction, start=b'0x75', end=b'\x76', delta=1e-14)
      self._check(fraction, start=b'0x75', end=b'\x77', delta=1e-14)
      self._check(fraction, start=b'0x75', end=b'\x78', delta=1e-14)
      self._check(
          fraction,
          start=b'a' * 100 + b'\x80',
          end=b'a' * 100 + b'\x81',
          delta=1e-14)

  def test_good_prec(self):
    """Checks the precision of keys produced for the fraction 1 / e."""
    # There should be about 7 characters (~53 bits) of precision
    # (beyond the common prefix of start and end).
    self._check(
        1 / math.e,
        start='AAAAAAA',
        end='zzzzzzz',
        key='VNg/ot\x82',
        delta=1e-14)
    self._check(
        1 / math.e,
        start=b'abc_abc',
        end=b'abc_xyz',
        key=b'abc_i\xe0\xf4\x84\x86\x99\x96',
        delta=1e-15)
    # This remains true even if the start and end keys are given to
    # high precision.
    self._check(
        1 / math.e,
        start=b'abcd_abc\0\0\0\0\0_______________abc',
        end=b'abcd_xyz\0\0\0\0\0\0_______________abc',
        key=b'abcd_i\xe0\xf4\x84\x86\x99\x96',
        delta=1e-15)
    # For very small fractions, however, higher precision is used to
    # accurately represent small increments in the keyspace.
    self._check(
        1e-20 / math.e,
        start=b'abcd_abc',
        end=b'abcd_xyz',
        key=b'abcd_abc\x00\x00\x00\x00\x00\x01\x91#\x172N\xbb',
        delta=1e-35)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()