github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/mongodbio_test.py

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import datetime
import logging
import random
import unittest
from typing import Union
from unittest import TestCase

import mock
from bson import ObjectId
from bson import objectid
from parameterized import parameterized_class
from pymongo import ASCENDING
from pymongo import ReplaceOne

import apache_beam as beam
from apache_beam.io import ReadFromMongoDB
from apache_beam.io import WriteToMongoDB
from apache_beam.io import source_test_utils
from apache_beam.io.mongodbio import _BoundedMongoSource
from apache_beam.io.mongodbio import _GenerateObjectIdFn
from apache_beam.io.mongodbio import _MongoSink
from apache_beam.io.mongodbio import _ObjectIdHelper
from apache_beam.io.mongodbio import _ObjectIdRangeTracker
from apache_beam.io.mongodbio import _WriteMongoFn
from apache_beam.io.range_trackers import LexicographicKeyRangeTracker
from apache_beam.io.range_trackers import OffsetRangeTracker
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to


class _MockMongoColl(object):
  """Fake MongoDB collection cursor."""
  def __init__(self, docs):
    self.docs = docs

  def __getitem__(self, index):
    return self.docs[index]

  def __len__(self):
    return len(self.docs)

  @staticmethod
  def _make_filter(conditions):
    assert isinstance(conditions, dict)
    checks = []
    for field, value in conditions.items():
      if isinstance(value, dict):
        for op, val in value.items():
          if op == '$gte':
            op = '__ge__'
          elif op == '$lt':
            op = '__lt__'
          else:
            raise Exception('Operator "{0}" not supported.'.format(op))
          checks.append((field, op, val))
      else:
        checks.append((field, '__eq__', value))

    def func(doc):
      for field, op, value in checks:
        if not getattr(doc[field], op)(value):
          return False
      return True

    return func

  def _filter(self, filter):
    match = []
    if not filter:
      return self
    all_filters = []
    if '$and' in filter:
      for item in filter['$and']:
        all_filters.append(self._make_filter(item))
    else:
      all_filters.append(self._make_filter(filter))

    for doc in self.docs:
      if not all(check(doc) for check in all_filters):
        continue
      match.append(doc)

    return match

  @staticmethod
  def _projection(docs, projection=None):
    if projection:
      return [{k: v
               for k, v in doc.items() if k in projection or k == '_id'}
              for doc in docs]
    return docs

  def find(self, filter=None, projection=None, **kwargs):
    return _MockMongoColl(self._projection(self._filter(filter), projection))

  def sort(self, sort_items):
    key, order = sort_items[0]
    self.docs = sorted(
        self.docs, key=lambda x: x[key], reverse=(order != ASCENDING))
    return self

  def limit(self, num):
    return _MockMongoColl(self.docs[0:num])

  def count_documents(self, filter):
    return len(self._filter(filter))

  def aggregate(self, pipeline, **kwargs):
    # Simulate $bucketAuto aggregate pipeline.
    # Example splits doc count for the total of 5 docs:
    # - 1 bucket:  [5]
    # - 2 buckets: [3, 2]
    # - 3 buckets: [2, 2, 1]
    # - 4 buckets: [2, 1, 1, 1]
    # - 5 buckets: [1, 1, 1, 1, 1]
    match_step = next((step for step in pipeline if '$match' in step), None)
    bucket_auto_step = next(step for step in pipeline if '$bucketAuto' in step)
    if match_step is None:
      docs = self.docs
    else:
      docs = self.find(filter=match_step['$match'])
    doc_count = len(docs)
    bucket_count = min(bucket_auto_step['$bucketAuto']['buckets'], doc_count)
    # bucket_count != 0
    bucket_len, remainder = divmod(doc_count, bucket_count)
    bucket_sizes = (
        remainder * [bucket_len + 1] +
        (bucket_count - remainder) * [bucket_len])
    buckets = []
    start = 0
    for bucket_size in bucket_sizes:
      stop = start + bucket_size
      if stop >= doc_count:
        # MongoDB: the last bucket's 'max' is inclusive
        stop = doc_count - 1
        count = stop - start + 1
      else:
        # a non-last bucket's 'max' is exclusive and == next bucket's 'min'
        count = stop - start
      buckets.append({
          '_id': {
              'min': docs[start]['_id'],
              'max': docs[stop]['_id'],
          },
          'count': count
      })
      start = stop

    return buckets

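
# A minimal sketch (not part of the upstream tests) of the bucket-size
# arithmetic used by the $bucketAuto simulation above: documents are spread
# as evenly as possible, with the remainder going to the leading buckets.
# The helper below is hypothetical and exists for illustration only.
def _demo_bucket_sizes(doc_count, bucket_count):
  bucket_len, remainder = divmod(doc_count, bucket_count)
  return (
      remainder * [bucket_len + 1] + (bucket_count - remainder) * [bucket_len])


# e.g. _demo_bucket_sizes(5, 3) == [2, 2, 1]
# e.g. _demo_bucket_sizes(5, 4) == [2, 1, 1, 1]
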

class _MockMongoDb(object):
  """Fake MongoDB database."""
  def __init__(self, docs):
    self.docs = docs

  def __getitem__(self, coll_name):
    return _MockMongoColl(self.docs)

  def command(self, command, *args, **kwargs):
    if command == 'collstats':
      return {'size': 5 * 1024 * 1024, 'avgObjSize': 1 * 1024 * 1024}
    if command == 'splitVector':
      return self.get_split_keys(command, *args, **kwargs)

  def get_split_keys(self, command, ns, min, max, maxChunkSize, **kwargs):
    # Simulate the MongoDB splitVector command: return split keys based on
    # chunk size, assuming every doc is 1 MB in size.
    start_id = min['_id']
    end_id = max['_id']
    if start_id >= end_id:
      return []
    start_index = 0
    end_index = 0
    # get the split range of [min, max]
    for doc in self.docs:
      if doc['_id'] < start_id:
        start_index += 1
      if doc['_id'] <= end_id:
        end_index += 1
      else:
        break
    # Return ids of elements in the range, stepping by chunk size, and
    # exclude the head element. For simplicity of the tests, every document
    # is considered 1 MB by default.
    return {
        'splitKeys': [{
            '_id': x['_id']
        } for x in self.docs[start_index:end_index:maxChunkSize]][1:]
    }


class _MockMongoClient:
  def __init__(self, docs):
    self.docs = docs

  def __getitem__(self, db_name):
    return _MockMongoDb(self.docs)

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    pass

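
# A minimal sketch (illustration only, not used by the tests) of how the
# mocks above emulate the small slice of the pymongo API that mongodbio
# relies on; the helper name is hypothetical.
def _demo_mock_client():
  docs = [{'_id': i, 'x': i, 'y': 'dropped'} for i in range(5)]
  with _MockMongoClient(docs) as client:
    coll = client['testdb']['testcoll']
    # count_documents applies the same $gte/$lt filtering as find().
    assert coll.count_documents({'x': {'$gte': 1, '$lt': 3}}) == 2
    # find() honors projections but, like MongoDB, always keeps '_id'.
    assert coll.find(projection=['x'])[0] == {'_id': 0, 'x': 0}
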

# Generate test data for MongoDB collections of different types
OBJECT_IDS = [
    objectid.ObjectId.from_datetime(
        datetime.datetime(year=2020, month=i + 1, day=i + 1)) for i in range(5)
]

INT_IDS = [n for n in range(5)]  # [0, 1, 2, 3, 4]

STR_IDS_1 = [str(n) for n in range(5)]  # ['0', '1', '2', '3', '4']

# ['aaaaa', 'bbbbb', 'ccccc', 'ddddd', 'eeeee']
STR_IDS_2 = [chr(97 + n) * 5 for n in range(5)]

# ['AAAAAAAAAAAAAAAAAAAA', 'BBBBBBBBBBBBBBBBBBBB', ..., 'EEEEEEEEEEEEEEEEEEEE']
STR_IDS_3 = [chr(65 + n) * 20 for n in range(5)]


@parameterized_class(('bucket_auto', '_ids', 'min_id', 'max_id'),
                     [
                         (None, OBJECT_IDS, _ObjectIdHelper.int_to_id(0),
                          _ObjectIdHelper.int_to_id(2**96 - 1)),
                         (True, OBJECT_IDS, _ObjectIdHelper.int_to_id(0),
                          _ObjectIdHelper.int_to_id(2**96 - 1)),
                         (None, INT_IDS, 0, 2**96 - 1),
                         (True, INT_IDS, 0, 2**96 - 1),
                         (None, STR_IDS_1, chr(0), chr(0x10ffff)),
                         (True, STR_IDS_1, chr(0), chr(0x10ffff)),
                         (None, STR_IDS_2, chr(0), chr(0x10ffff)),
                         (True, STR_IDS_2, chr(0), chr(0x10ffff)),
                         (None, STR_IDS_3, chr(0), chr(0x10ffff)),
                         (True, STR_IDS_3, chr(0), chr(0x10ffff)),
                     ])
class MongoSourceTest(unittest.TestCase):
  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def setUp(self, mock_client):
    self._docs = [{'_id': self._ids[i], 'x': i} for i in range(len(self._ids))]
    mock_client.return_value = _MockMongoClient(self._docs)

    self.mongo_source = self._create_source(bucket_auto=self.bucket_auto)

  @staticmethod
  def _create_source(filter=None, bucket_auto=None):
    kwargs = {}
    if filter is not None:
      kwargs['filter'] = filter
    if bucket_auto is not None:
      kwargs['bucket_auto'] = bucket_auto
    return _BoundedMongoSource('mongodb://test', 'testdb', 'testcoll', **kwargs)

  def _increment_id(
      self,
      _id: Union[ObjectId, int, str],
      inc: int,
  ) -> Union[ObjectId, int, str]:
    """Helper method to increment `_id` of different types."""

    if isinstance(_id, ObjectId):
      return _ObjectIdHelper.increment_id(_id, inc)

    if isinstance(_id, int):
      return _id + inc

    if isinstance(_id, str):
      index = self._ids.index(_id) + inc
      if index <= 0:
        return self._ids[0]
      if index >= len(self._ids):
        return self._ids[-1]
      return self._ids[index]

  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def test_estimate_size(self, mock_client):
    mock_client.return_value = _MockMongoClient(self._docs)
    self.assertEqual(self.mongo_source.estimate_size(), 5 * 1024 * 1024)

  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def test_estimate_average_document_size(self, mock_client):
    mock_client.return_value = _MockMongoClient(self._docs)
    self.assertEqual(
        self.mongo_source._estimate_average_document_size(), 1 * 1024 * 1024)

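  # For the split tests below: the mocked collstats reports a 5 MB collection
  # of 1 MB documents, so a 2 MB desired bundle size yields three splits
  # (2 + 2 + 1 documents) and a 10 MB bundle size yields a single split.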
  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def test_split(self, mock_client):
    mock_client.return_value = _MockMongoClient(self._docs)
    for size_mb, expected_split_count in [(0.5, 5), (1, 5), (2, 3), (10, 1)]:
      size = size_mb * 1024 * 1024
      splits = list(
          self.mongo_source.split(
              start_position=None, stop_position=None,
              desired_bundle_size=size))

      self.assertEqual(len(splits), expected_split_count)
      reference_info = (self.mongo_source, None, None)
      sources_info = ([
          (split.source, split.start_position, split.stop_position)
          for split in splits
      ])
      source_test_utils.assert_sources_equal_reference_source(
          reference_info, sources_info)

  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def test_split_single_document(self, mock_client):
    mock_client.return_value = _MockMongoClient(self._docs[0:1])
    for size_mb in [1, 5]:
      size = size_mb * 1024 * 1024
      splits = list(
          self.mongo_source.split(
              start_position=None, stop_position=None,
              desired_bundle_size=size))
      self.assertEqual(len(splits), 1)
      _id = self._docs[0]['_id']
      assert _id == splits[0].start_position
      assert _id <= splits[0].stop_position
      if isinstance(_id, (ObjectId, int)):
        # We can unambiguously determine the next `_id`.
        assert self._increment_id(_id, 1) == splits[0].stop_position

  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def test_split_no_documents(self, mock_client):
    mock_client.return_value = _MockMongoClient([])
    with self.assertRaises(ValueError) as cm:
      list(
          self.mongo_source.split(
              start_position=None,
              stop_position=None,
              desired_bundle_size=1024 * 1024))
    self.assertEqual(str(cm.exception), 'Empty Mongodb collection')

  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def test_split_filtered(self, mock_client):
    # filter matches 2 documents: 2 <= 'x' < 4
    filtered_mongo_source = self._create_source(
        filter={'x': {
            '$gte': 2, '$lt': 4
        }}, bucket_auto=self.bucket_auto)

    mock_client.return_value = _MockMongoClient(self._docs)
    for size_mb, (bucket_auto_count, split_vector_count) in [(1, (2, 5)),
                                                             (2, (1, 3)),
                                                             (10, (1, 1))]:
      size = size_mb * 1024 * 1024
      splits = list(
          filtered_mongo_source.split(
              start_position=None, stop_position=None,
              desired_bundle_size=size))

      if self.bucket_auto:
        self.assertEqual(len(splits), bucket_auto_count)
      else:
        # Note: splitVector mode does not respect the filter.
        self.assertEqual(len(splits), split_vector_count)
      reference_info = (
          filtered_mongo_source, self._docs[2]['_id'], self._docs[4]['_id'])
      sources_info = ([
          (split.source, split.start_position, split.stop_position)
          for split in splits
      ])
      source_test_utils.assert_sources_equal_reference_source(
          reference_info, sources_info)

  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def test_split_filtered_empty(self, mock_client):
    # filter doesn't match any documents
    filtered_mongo_source = self._create_source(
        filter={'x': {
            '$lt': 0
        }}, bucket_auto=self.bucket_auto)

    mock_client.return_value = _MockMongoClient(self._docs)
    for size_mb, (bucket_auto_count, split_vector_count) in [(1, (1, 5)),
                                                             (2, (1, 3)),
                                                             (10, (1, 1))]:
      size = size_mb * 1024 * 1024
      splits = list(
          filtered_mongo_source.split(
              start_position=None, stop_position=None,
              desired_bundle_size=size))

      if self.bucket_auto:
        # Note: if the filter matches no docs, one split covers the
        # entire range.
        self.assertEqual(len(splits), bucket_auto_count)
      else:
        # Note: splitVector mode does not respect the filter.
        self.assertEqual(len(splits), split_vector_count)
      reference_info = (
          filtered_mongo_source,
          # a range that matches no documents:
          self._increment_id(self._docs[-1]['_id'], 1),
          self._increment_id(self._docs[-1]['_id'], 2),
      )
      sources_info = ([
          (split.source, split.start_position, split.stop_position)
          for split in splits
      ])
      source_test_utils.assert_sources_equal_reference_source(
          reference_info, sources_info)

  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def test_dynamic_work_rebalancing(self, mock_client):
    mock_client.return_value = _MockMongoClient(self._docs)
    splits = list(
        self.mongo_source.split(desired_bundle_size=3000 * 1024 * 1024))
    assert len(splits) == 1
    source_test_utils.assert_split_at_fraction_exhaustive(
        splits[0].source, splits[0].start_position, splits[0].stop_position)

  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def test_get_range_tracker(self, mock_client):
    mock_client.return_value = _MockMongoClient(self._docs)
    if self._ids == OBJECT_IDS:
      self.assertIsInstance(
          self.mongo_source.get_range_tracker(None, None),
          _ObjectIdRangeTracker,
      )
    elif self._ids == INT_IDS:
      self.assertIsInstance(
          self.mongo_source.get_range_tracker(None, None),
          OffsetRangeTracker,
      )
    elif self._ids == STR_IDS_1:
      self.assertIsInstance(
          self.mongo_source.get_range_tracker(None, None),
          LexicographicKeyRangeTracker,
      )

  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def test_read(self, mock_client):
    mock_tracker = mock.MagicMock()
    test_cases = [
        {
            # range covers the first (inclusive) to the third (exclusive)
            # document
            'start': self._ids[0],
            'stop': self._ids[2],
            'expected': self._docs[0:2]
        },
        {
            # range covers from the first to the third document
            'start': self.min_id,  # smallest possible id
            'stop': self._ids[2],
            'expected': self._docs[0:2]
        },
        {
            # range covers from the third to the last document
            'start': self._ids[2],
            'stop': self.max_id,  # largest possible id
            'expected': self._docs[2:]
        },
        {
            # range covers all documents
            'start': self.min_id,
            'stop': self.max_id,
            'expected': self._docs
        },
        {
            # range doesn't include any document
            'start': self._increment_id(self._ids[2], 1),
            'stop': self._increment_id(self._ids[3], -1),
            'expected': []
        },
    ]
    mock_client.return_value = _MockMongoClient(self._docs)
    for case in test_cases:
      mock_tracker.start_position.return_value = case['start']
      mock_tracker.stop_position.return_value = case['stop']
      result = list(self.mongo_source.read(mock_tracker))
      self.assertListEqual(case['expected'], result)

  def test_display_data(self):
    data = self.mongo_source.display_data()
    self.assertTrue('database' in data)
    self.assertTrue('collection' in data)

  def test_range_is_not_splittable(self):
    self.assertTrue(
        self.mongo_source._range_is_not_splittable(
            _ObjectIdHelper.int_to_id(1),
            _ObjectIdHelper.int_to_id(1),
        ))
    self.assertTrue(
        self.mongo_source._range_is_not_splittable(
            _ObjectIdHelper.int_to_id(1),
            _ObjectIdHelper.int_to_id(2),
        ))
    self.assertFalse(
        self.mongo_source._range_is_not_splittable(
            _ObjectIdHelper.int_to_id(1),
            _ObjectIdHelper.int_to_id(3),
        ))

    self.assertTrue(self.mongo_source._range_is_not_splittable(0, 0))
    self.assertTrue(self.mongo_source._range_is_not_splittable(0, 1))
    self.assertFalse(self.mongo_source._range_is_not_splittable(0, 2))

    self.assertTrue(self.mongo_source._range_is_not_splittable("AAA", "AAA"))
    self.assertFalse(
        self.mongo_source._range_is_not_splittable("AAA", "AAA\x00"))
    self.assertFalse(self.mongo_source._range_is_not_splittable("AAA", "AAB"))

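
# MongoDB projections keep '_id' alongside the requested fields unless it is
# explicitly excluded, and the read tests below rely on that. A minimal
# sketch of the behavior (the helper is hypothetical, for illustration only):
def _demo_projection(doc, projection):
  return {k: v for k, v in doc.items() if k in projection or k == '_id'}


# e.g. _demo_projection({'_id': 1, 'x': 2, 'y': 3}, ['x']) == {'_id': 1, 'x': 2}
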

@parameterized_class(('bucket_auto', ), [(False, ), (True, )])
class ReadFromMongoDBTest(unittest.TestCase):
  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def test_read_from_mongodb(self, mock_client):
    documents = [{
        '_id': objectid.ObjectId(), 'x': i, 'selected': 1, 'unselected': 2
    } for i in range(3)]
    mock_client.return_value = _MockMongoClient(documents)

    projection = ['x', 'selected']
    projected_documents = [{
        k: v
        for k, v in e.items() if k in projection or k == '_id'
    } for e in documents]

    with TestPipeline() as p:
      docs = p | 'ReadFromMongoDB' >> ReadFromMongoDB(
          uri='mongodb://test',
          db='db',
          coll='collection',
          projection=projection,
          bucket_auto=self.bucket_auto)
      assert_that(docs, equal_to(projected_documents))


class GenerateObjectIdFnTest(unittest.TestCase):
  def test_process(self):
    with TestPipeline() as p:
      output = (
          p | "Create" >> beam.Create([{
              'x': 1
          }, {
              'x': 2, '_id': 123
          }])
          | "Generate ID" >> beam.ParDo(_GenerateObjectIdFn())
          | "Check" >> beam.Map(lambda x: '_id' in x))
      assert_that(output, equal_to([True] * 2))


class WriteMongoFnTest(unittest.TestCase):
  @mock.patch('apache_beam.io.mongodbio._MongoSink')
  def test_process(self, mock_sink):
    docs = [{'x': 1}, {'x': 2}, {'x': 3}]
    with TestPipeline() as p:
      _ = (
          p | "Create" >> beam.Create(docs)
          | "Write" >> beam.ParDo(_WriteMongoFn(batch_size=2)))
      p.run()

    self.assertEqual(
        2, mock_sink.return_value.__enter__.return_value.write.call_count)

  def test_display_data(self):
    data = _WriteMongoFn(batch_size=10).display_data()
    self.assertEqual(10, data['batch_size'])


class MongoSinkTest(unittest.TestCase):
  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def test_write(self, mock_client):
    docs = [{'x': 1}, {'x': 2}, {'x': 3}]
    _MongoSink(uri='test', db='test', coll='test').write(docs)
    self.assertTrue(
        mock_client.return_value.__getitem__.return_value.__getitem__.
        return_value.bulk_write.called)


class WriteToMongoDBTest(unittest.TestCase):
  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def test_write_to_mongodb_with_existing_id(self, mock_client):
    _id = objectid.ObjectId()
    docs = [{'x': 1, '_id': _id}]
    expected_update = [
        ReplaceOne({'_id': _id}, {
            'x': 1, '_id': _id
        }, True, None)
    ]
    with TestPipeline() as p:
      _ = (
          p | "Create" >> beam.Create(docs)
          | "Write" >> WriteToMongoDB(db='test', coll='test'))
      p.run()
    mock_client.return_value.__getitem__.return_value.__getitem__. \
        return_value.bulk_write.assert_called_with(expected_update)

  @mock.patch('apache_beam.io.mongodbio.MongoClient')
  def test_write_to_mongodb_with_generated_id(self, mock_client):
    docs = [{'x': 1}]
    expected_update = [
        ReplaceOne({'_id': mock.ANY}, {
            'x': 1, '_id': mock.ANY
        }, True, None)
    ]
    with TestPipeline() as p:
      _ = (
          p | "Create" >> beam.Create(docs)
          | "Write" >> WriteToMongoDB(db='test', coll='test'))
      p.run()
    mock_client.return_value.__getitem__.return_value.__getitem__. \
        return_value.bulk_write.assert_called_with(expected_update)

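
# The helper tests below treat a 12-byte ObjectId as a 96-bit big-endian
# integer. A minimal sketch of that round trip using only bson (the helper
# name is hypothetical, for illustration only):
def _demo_objectid_roundtrip():
  _id = objectid.ObjectId('0000000000000000ffffffff')
  number = int(_id.binary.hex(), 16)  # 12 bytes -> 96-bit integer
  assert number == 2**32 - 1
  assert objectid.ObjectId(number.to_bytes(12, 'big')) == _id
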

class ObjectIdHelperTest(TestCase):
  def test_conversion(self):
    test_cases = [
        (objectid.ObjectId('000000000000000000000000'), 0),
        (objectid.ObjectId('000000000000000100000000'), 2**32),
        (objectid.ObjectId('0000000000000000ffffffff'), 2**32 - 1),
        (objectid.ObjectId('000000010000000000000000'), 2**64),
        (objectid.ObjectId('00000000ffffffffffffffff'), 2**64 - 1),
        (objectid.ObjectId('ffffffffffffffffffffffff'), 2**96 - 1),
    ]
    for (_id, number) in test_cases:
      self.assertEqual(_id, _ObjectIdHelper.int_to_id(number))
      self.assertEqual(number, _ObjectIdHelper.id_to_int(_id))

    # random tests
    for _ in range(100):
      _id = objectid.ObjectId()
      number = int(_id.binary.hex(), 16)
      self.assertEqual(_id, _ObjectIdHelper.int_to_id(number))
      self.assertEqual(number, _ObjectIdHelper.id_to_int(_id))

  def test_increment_id(self):
    test_cases = [
        (
            objectid.ObjectId("000000000000000100000000"),
            objectid.ObjectId("0000000000000000ffffffff"),
        ),
        (
            objectid.ObjectId("000000010000000000000000"),
            objectid.ObjectId("00000000ffffffffffffffff"),
        ),
    ]
    for first, second in test_cases:
      self.assertEqual(second, _ObjectIdHelper.increment_id(first, -1))
      self.assertEqual(first, _ObjectIdHelper.increment_id(second, 1))

    for _ in range(100):
      _id = objectid.ObjectId()
      self.assertLess(_id, _ObjectIdHelper.increment_id(_id, 1))
      self.assertGreater(_id, _ObjectIdHelper.increment_id(_id, -1))


class ObjectRangeTrackerTest(TestCase):
  def test_fraction_position_conversion(self):
    start_int = 0
    stop_int = 2**96 - 1
    start = _ObjectIdHelper.int_to_id(start_int)
    stop = _ObjectIdHelper.int_to_id(stop_int)
    test_cases = ([start_int, stop_int, 2**32, 2**32 - 1, 2**64, 2**64 - 1] +
                  [random.randint(start_int, stop_int) for _ in range(100)])
    tracker = _ObjectIdRangeTracker()
    for pos in test_cases:
      _id = _ObjectIdHelper.int_to_id(pos - start_int)
      desired_fraction = (pos - start_int) / (stop_int - start_int)
      self.assertAlmostEqual(
          tracker.position_to_fraction(_id, start, stop),
          desired_fraction,
          places=20)

      convert_id = tracker.fraction_to_position(
          (pos - start_int) / (stop_int - start_int), start, stop)
      # Due to precision loss, the converted fraction will only be close to
      # the original fraction.
      convert_fraction = tracker.position_to_fraction(convert_id, start, stop)

      self.assertGreater(convert_id, start)
      self.assertLess(convert_id, stop)
      self.assertAlmostEqual(convert_fraction, desired_fraction, places=20)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()