github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/mongodbio_test.py (about)

     1  # Licensed to the Apache Software Foundation (ASF) under one or more
     2  # contributor license agreements.  See the NOTICE file distributed with
     3  # this work for additional information regarding copyright ownership.
     4  # The ASF licenses this file to You under the Apache License, Version 2.0
     5  # (the "License"); you may not use this file except in compliance with
     6  # the License.  You may obtain a copy of the License at
     7  #
     8  #    http://www.apache.org/licenses/LICENSE-2.0
     9  #
    10  # Unless required by applicable law or agreed to in writing, software
    11  # distributed under the License is distributed on an "AS IS" BASIS,
    12  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  # See the License for the specific language governing permissions and
    14  # limitations under the License.
    15  #
    16  
    17  # pytype: skip-file
    18  
    19  import datetime
    20  import logging
    21  import random
    22  import unittest
    23  from typing import Union
    24  from unittest import TestCase
    25  
    26  import mock
    27  from bson import ObjectId
    28  from bson import objectid
    29  from parameterized import parameterized_class
    30  from pymongo import ASCENDING
    31  from pymongo import ReplaceOne
    32  
    33  import apache_beam as beam
    34  from apache_beam.io import ReadFromMongoDB
    35  from apache_beam.io import WriteToMongoDB
    36  from apache_beam.io import source_test_utils
    37  from apache_beam.io.mongodbio import _BoundedMongoSource
    38  from apache_beam.io.mongodbio import _GenerateObjectIdFn
    39  from apache_beam.io.mongodbio import _MongoSink
    40  from apache_beam.io.mongodbio import _ObjectIdHelper
    41  from apache_beam.io.mongodbio import _ObjectIdRangeTracker
    42  from apache_beam.io.mongodbio import _WriteMongoFn
    43  from apache_beam.io.range_trackers import LexicographicKeyRangeTracker
    44  from apache_beam.io.range_trackers import OffsetRangeTracker
    45  from apache_beam.testing.test_pipeline import TestPipeline
    46  from apache_beam.testing.util import assert_that
    47  from apache_beam.testing.util import equal_to
    48  
    49  
    50  class _MockMongoColl(object):
    51    """Fake mongodb collection cursor."""
    52    def __init__(self, docs):
    53      self.docs = docs
    54  
    55    def __getitem__(self, index):
    56      return self.docs[index]
    57  
    58    def __len__(self):
    59      return len(self.docs)
    60  
    61    @staticmethod
    62    def _make_filter(conditions):
    63      assert isinstance(conditions, dict)
    64      checks = []
    65      for field, value in conditions.items():
    66        if isinstance(value, dict):
    67          for op, val in value.items():
    68            if op == '$gte':
    69              op = '__ge__'
    70            elif op == '$lt':
    71              op = '__lt__'
    72            else:
    73              raise Exception('Operator "{0}" not supported.'.format(op))
    74            checks.append((field, op, val))
    75        else:
    76          checks.append((field, '__eq__', value))
    77  
    78      def func(doc):
    79        for field, op, value in checks:
    80          if not getattr(doc[field], op)(value):
    81            return False
    82        return True
    83  
    84      return func
    85  
    86    def _filter(self, filter):
    87      match = []
    88      if not filter:
    89        return self
    90      all_filters = []
    91      if '$and' in filter:
    92        for item in filter['$and']:
    93          all_filters.append(self._make_filter(item))
    94      else:
    95        all_filters.append(self._make_filter(filter))
    96  
    97      for doc in self.docs:
    98        if not all(check(doc) for check in all_filters):
    99          continue
   100        match.append(doc)
   101  
   102      return match
   103  
   104    @staticmethod
   105    def _projection(docs, projection=None):
   106      if projection:
   107        return [{k: v
   108                 for k, v in doc.items() if k in projection or k == '_id'}
   109                for doc in docs]
   110      return docs
   111  
   112    def find(self, filter=None, projection=None, **kwargs):
   113      return _MockMongoColl(self._projection(self._filter(filter), projection))
   114  
   115    def sort(self, sort_items):
   116      key, order = sort_items[0]
   117      self.docs = sorted(
   118          self.docs, key=lambda x: x[key], reverse=(order != ASCENDING))
   119      return self
   120  
   121    def limit(self, num):
   122      return _MockMongoColl(self.docs[0:num])
   123  
   124    def count_documents(self, filter):
   125      return len(self._filter(filter))
   126  
   127    def aggregate(self, pipeline, **kwargs):
   128      # Simulate $bucketAuto aggregate pipeline.
   129      # Example splits doc count for the total of 5 docs:
   130      #   - 1 bucket:  [5]
   131      #   - 2 buckets: [3, 2]
   132      #   - 3 buckets: [2, 2, 1]
   133      #   - 4 buckets: [2, 1, 1, 1]
   134      #   - 5 buckets: [1, 1, 1, 1, 1]
   135      match_step = next((step for step in pipeline if '$match' in step), None)
   136      bucket_auto_step = next(step for step in pipeline if '$bucketAuto' in step)
   137      if match_step is None:
   138        docs = self.docs
   139      else:
   140        docs = self.find(filter=match_step['$match'])
   141      doc_count = len(docs)
   142      bucket_count = min(bucket_auto_step['$bucketAuto']['buckets'], doc_count)
   143      # bucket_count ≠ 0
   144      bucket_len, remainder = divmod(doc_count, bucket_count)
   145      bucket_sizes = (
   146          remainder * [bucket_len + 1] +
   147          (bucket_count - remainder) * [bucket_len])
   148      buckets = []
   149      start = 0
   150      for bucket_size in bucket_sizes:
   151        stop = start + bucket_size
   152        if stop >= doc_count:
   153          # MongoDB: the last bucket's 'max' is inclusive
   154          stop = doc_count - 1
   155          count = stop - start + 1
   156        else:
   157          # non-last bucket's 'max' is exclusive and == next bucket's 'min'
   158          count = stop - start
   159        buckets.append({
   160            '_id': {
   161                'min': docs[start]['_id'],
   162                'max': docs[stop]['_id'],
   163            },
   164            'count': count
   165        })
   166        start = stop
   167  
   168      return buckets
   169  
   170  
   171  class _MockMongoDb(object):
   172    """Fake Mongo Db."""
   173    def __init__(self, docs):
   174      self.docs = docs
   175  
   176    def __getitem__(self, coll_name):
   177      return _MockMongoColl(self.docs)
   178  
   179    def command(self, command, *args, **kwargs):
   180      if command == 'collstats':
   181        return {'size': 5 * 1024 * 1024, 'avgObjSize': 1 * 1024 * 1024}
   182      if command == 'splitVector':
   183        return self.get_split_keys(command, *args, **kwargs)
   184  
   185    def get_split_keys(self, command, ns, min, max, maxChunkSize, **kwargs):
   186      # simulate mongo db splitVector command, return split keys base on chunk
   187      # size, assuming every doc is of size 1mb
   188      start_id = min['_id']
   189      end_id = max['_id']
   190      if start_id >= end_id:
   191        return []
   192      start_index = 0
   193      end_index = 0
   194      # get split range of [min, max]
   195      for doc in self.docs:
   196        if doc['_id'] < start_id:
   197          start_index += 1
   198        if doc['_id'] <= end_id:
   199          end_index += 1
   200        else:
   201          break
   202      # Return ids of elements in the range with chunk size skip and exclude
   203      # head element. For simplicity of tests every document is considered 1Mb
   204      # by default.
   205      return {
   206          'splitKeys': [{
   207              '_id': x['_id']
   208          } for x in self.docs[start_index:end_index:maxChunkSize]][1:]
   209      }
   210  
   211  
   212  class _MockMongoClient:
   213    def __init__(self, docs):
   214      self.docs = docs
   215  
   216    def __getitem__(self, db_name):
   217      return _MockMongoDb(self.docs)
   218  
   219    def __enter__(self):
   220      return self
   221  
   222    def __exit__(self, exc_type, exc_val, exc_tb):
   223      pass
   224  
   225  
   226  # Generate test data for MongoDB collections of different types
   227  OBJECT_IDS = [
   228      objectid.ObjectId.from_datetime(
   229          datetime.datetime(year=2020, month=i + 1, day=i + 1)) for i in range(5)
   230  ]
   231  
   232  INT_IDS = [n for n in range(5)]  # [0, 1, 2, 3, 4]
   233  
   234  STR_IDS_1 = [str(n) for n in range(5)]  # ['0', '1', '2', '3', '4']
   235  
   236  # ['aaaaa', 'bbbbb', 'ccccc', 'ddddd', 'eeeee']
   237  STR_IDS_2 = [chr(97 + n) * 5 for n in range(5)]
   238  
   239  # ['AAAAAAAAAAAAAAAAAAAA', 'BBBBBBBBBBBBBBBBBBBB', ..., 'EEEEEEEEEEEEEEEEEEEE']
   240  STR_IDS_3 = [chr(65 + n) * 20 for n in range(5)]
   241  
   242  
   243  @parameterized_class(('bucket_auto', '_ids', 'min_id', 'max_id'),
   244                       [
   245                           (
   246                               None,
   247                               OBJECT_IDS,
   248                               _ObjectIdHelper.int_to_id(0),
   249                               _ObjectIdHelper.int_to_id(2**96 - 1)),
   250                           (
   251                               True,
   252                               OBJECT_IDS,
   253                               _ObjectIdHelper.int_to_id(0),
   254                               _ObjectIdHelper.int_to_id(2**96 - 1)),
   255                           (
   256                               None,
   257                               INT_IDS,
   258                               0,
   259                               2**96 - 1,
   260                           ),
   261                           (
   262                               True,
   263                               INT_IDS,
   264                               0,
   265                               2**96 - 1,
   266                           ),
   267                           (
   268                               None,
   269                               STR_IDS_1,
   270                               chr(0),
   271                               chr(0x10ffff),
   272                           ),
   273                           (
   274                               True,
   275                               STR_IDS_1,
   276                               chr(0),
   277                               chr(0x10ffff),
   278                           ),
   279                           (
   280                               None,
   281                               STR_IDS_2,
   282                               chr(0),
   283                               chr(0x10ffff),
   284                           ),
   285                           (
   286                               True,
   287                               STR_IDS_2,
   288                               chr(0),
   289                               chr(0x10ffff),
   290                           ),
   291                           (
   292                               None,
   293                               STR_IDS_3,
   294                               chr(0),
   295                               chr(0x10ffff),
   296                           ),
   297                           (
   298                               True,
   299                               STR_IDS_3,
   300                               chr(0),
   301                               chr(0x10ffff),
   302                           ),
   303                       ])
   304  class MongoSourceTest(unittest.TestCase):
   305    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   306    def setUp(self, mock_client):
   307      self._docs = [{'_id': self._ids[i], 'x': i} for i in range(len(self._ids))]
   308      mock_client.return_value = _MockMongoClient(self._docs)
   309  
   310      self.mongo_source = self._create_source(bucket_auto=self.bucket_auto)
   311  
   312    @staticmethod
   313    def _create_source(filter=None, bucket_auto=None):
   314      kwargs = {}
   315      if filter is not None:
   316        kwargs['filter'] = filter
   317      if bucket_auto is not None:
   318        kwargs['bucket_auto'] = bucket_auto
   319      return _BoundedMongoSource('mongodb://test', 'testdb', 'testcoll', **kwargs)
   320  
   321    def _increment_id(
   322        self,
   323        _id: Union[ObjectId, int, str],
   324        inc: int,
   325    ) -> Union[ObjectId, int, str]:
   326      """Helper method to increment `_id` of different types."""
   327  
   328      if isinstance(_id, ObjectId):
   329        return _ObjectIdHelper.increment_id(_id, inc)
   330  
   331      if isinstance(_id, int):
   332        return _id + inc
   333  
   334      if isinstance(_id, str):
   335        index = self._ids.index(_id) + inc
   336        if index <= 0:
   337          return self._ids[0]
   338        if index >= len(self._ids):
   339          return self._ids[-1]
   340        return self._ids[index]
   341  
   342    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   343    def test_estimate_size(self, mock_client):
   344      mock_client.return_value = _MockMongoClient(self._docs)
   345      self.assertEqual(self.mongo_source.estimate_size(), 5 * 1024 * 1024)
   346  
   347    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   348    def test_estimate_average_document_size(self, mock_client):
   349      mock_client.return_value = _MockMongoClient(self._docs)
   350      self.assertEqual(
   351          self.mongo_source._estimate_average_document_size(), 1 * 1024 * 1024)
   352  
   353    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   354    def test_split(self, mock_client):
   355      mock_client.return_value = _MockMongoClient(self._docs)
   356      for size_mb, expected_split_count in [(0.5, 5), (1, 5), (2, 3), (10, 1)]:
   357        size = size_mb * 1024 * 1024
   358        splits = list(
   359            self.mongo_source.split(
   360                start_position=None, stop_position=None,
   361                desired_bundle_size=size))
   362  
   363        self.assertEqual(len(splits), expected_split_count)
   364        reference_info = (self.mongo_source, None, None)
   365        sources_info = ([
   366            (split.source, split.start_position, split.stop_position)
   367            for split in splits
   368        ])
   369        source_test_utils.assert_sources_equal_reference_source(
   370            reference_info, sources_info)
   371  
   372    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   373    def test_split_single_document(self, mock_client):
   374      mock_client.return_value = _MockMongoClient(self._docs[0:1])
   375      for size_mb in [1, 5]:
   376        size = size_mb * 1024 * 1024
   377        splits = list(
   378            self.mongo_source.split(
   379                start_position=None, stop_position=None,
   380                desired_bundle_size=size))
   381        self.assertEqual(len(splits), 1)
   382        _id = self._docs[0]['_id']
   383        assert _id == splits[0].start_position
   384        assert _id <= splits[0].stop_position
   385        if isinstance(_id, (ObjectId, int)):
   386          # We can unambiguously determine next `_id`
   387          assert self._increment_id(_id, 1) == splits[0].stop_position
   388  
   389    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   390    def test_split_no_documents(self, mock_client):
   391      mock_client.return_value = _MockMongoClient([])
   392      with self.assertRaises(ValueError) as cm:
   393        list(
   394            self.mongo_source.split(
   395                start_position=None,
   396                stop_position=None,
   397                desired_bundle_size=1024 * 1024))
   398      self.assertEqual(str(cm.exception), 'Empty Mongodb collection')
   399  
   400    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   401    def test_split_filtered(self, mock_client):
   402      # filtering 2 documents: 2 <= 'x' < 4
   403      filtered_mongo_source = self._create_source(
   404          filter={'x': {
   405              '$gte': 2, '$lt': 4
   406          }}, bucket_auto=self.bucket_auto)
   407  
   408      mock_client.return_value = _MockMongoClient(self._docs)
   409      for size_mb, (bucket_auto_count, split_vector_count) in [(1, (2, 5)),
   410                                                               (2, (1, 3)),
   411                                                               (10, (1, 1))]:
   412        size = size_mb * 1024 * 1024
   413        splits = list(
   414            filtered_mongo_source.split(
   415                start_position=None, stop_position=None,
   416                desired_bundle_size=size))
   417  
   418        if self.bucket_auto:
   419          self.assertEqual(len(splits), bucket_auto_count)
   420        else:
   421          # Note: splitVector mode does not respect filter
   422          self.assertEqual(len(splits), split_vector_count)
   423        reference_info = (
   424            filtered_mongo_source, self._docs[2]['_id'], self._docs[4]['_id'])
   425        sources_info = ([
   426            (split.source, split.start_position, split.stop_position)
   427            for split in splits
   428        ])
   429        source_test_utils.assert_sources_equal_reference_source(
   430            reference_info, sources_info)
   431  
   432    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   433    def test_split_filtered_empty(self, mock_client):
   434      # filtering doesn't match any documents
   435      filtered_mongo_source = self._create_source(
   436          filter={'x': {
   437              '$lt': 0
   438          }}, bucket_auto=self.bucket_auto)
   439  
   440      mock_client.return_value = _MockMongoClient(self._docs)
   441      for size_mb, (bucket_auto_count, split_vector_count) in [(1, (1, 5)),
   442                                                               (2, (1, 3)),
   443                                                               (10, (1, 1))]:
   444        size = size_mb * 1024 * 1024
   445        splits = list(
   446            filtered_mongo_source.split(
   447                start_position=None, stop_position=None,
   448                desired_bundle_size=size))
   449  
   450        if self.bucket_auto:
   451          # Note: if filter matches no docs - one split covers entire range
   452          self.assertEqual(len(splits), bucket_auto_count)
   453        else:
   454          # Note: splitVector mode does not respect filter
   455          self.assertEqual(len(splits), split_vector_count)
   456        reference_info = (
   457            filtered_mongo_source,
   458            # range to match no documents:
   459            self._increment_id(self._docs[-1]['_id'], 1),
   460            self._increment_id(self._docs[-1]['_id'], 2),
   461        )
   462        sources_info = ([
   463            (split.source, split.start_position, split.stop_position)
   464            for split in splits
   465        ])
   466        source_test_utils.assert_sources_equal_reference_source(
   467            reference_info, sources_info)
   468  
   469    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   470    def test_dynamic_work_rebalancing(self, mock_client):
   471      mock_client.return_value = _MockMongoClient(self._docs)
   472      splits = list(
   473          self.mongo_source.split(desired_bundle_size=3000 * 1024 * 1024))
   474      assert len(splits) == 1
   475      source_test_utils.assert_split_at_fraction_exhaustive(
   476          splits[0].source, splits[0].start_position, splits[0].stop_position)
   477  
   478    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   479    def test_get_range_tracker(self, mock_client):
   480      mock_client.return_value = _MockMongoClient(self._docs)
   481      if self._ids == OBJECT_IDS:
   482        self.assertIsInstance(
   483            self.mongo_source.get_range_tracker(None, None),
   484            _ObjectIdRangeTracker,
   485        )
   486      elif self._ids == INT_IDS:
   487        self.assertIsInstance(
   488            self.mongo_source.get_range_tracker(None, None),
   489            OffsetRangeTracker,
   490        )
   491      elif self._ids == STR_IDS_1:
   492        self.assertIsInstance(
   493            self.mongo_source.get_range_tracker(None, None),
   494            LexicographicKeyRangeTracker,
   495        )
   496  
   497    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   498    def test_read(self, mock_client):
   499      mock_tracker = mock.MagicMock()
   500      test_cases = [
   501          {
   502              # range covers the first(inclusive) to third(exclusive) documents
   503              'start': self._ids[0],
   504              'stop': self._ids[2],
   505              'expected': self._docs[0:2]
   506          },
   507          {
   508              # range covers from the first to the third documents
   509              'start': self.min_id,  # smallest possible id
   510              'stop': self._ids[2],
   511              'expected': self._docs[0:2]
   512          },
   513          {
   514              # range covers from the third to last documents
   515              'start': self._ids[2],
   516              'stop': self.max_id,  # largest possible id
   517              'expected': self._docs[2:]
   518          },
   519          {
   520              # range covers all documents
   521              'start': self.min_id,
   522              'stop': self.max_id,
   523              'expected': self._docs
   524          },
   525          {
   526              # range doesn't include any document
   527              'start': self._increment_id(self._ids[2], 1),
   528              'stop': self._increment_id(self._ids[3], -1),
   529              'expected': []
   530          },
   531      ]
   532      mock_client.return_value = _MockMongoClient(self._docs)
   533      for case in test_cases:
   534        mock_tracker.start_position.return_value = case['start']
   535        mock_tracker.stop_position.return_value = case['stop']
   536        result = list(self.mongo_source.read(mock_tracker))
   537        self.assertListEqual(case['expected'], result)
   538  
   539    def test_display_data(self):
   540      data = self.mongo_source.display_data()
   541      self.assertTrue('database' in data)
   542      self.assertTrue('collection' in data)
   543  
   544    def test_range_is_not_splittable(self):
   545      self.assertTrue(
   546          self.mongo_source._range_is_not_splittable(
   547              _ObjectIdHelper.int_to_id(1),
   548              _ObjectIdHelper.int_to_id(1),
   549          ))
   550      self.assertTrue(
   551          self.mongo_source._range_is_not_splittable(
   552              _ObjectIdHelper.int_to_id(1),
   553              _ObjectIdHelper.int_to_id(2),
   554          ))
   555      self.assertFalse(
   556          self.mongo_source._range_is_not_splittable(
   557              _ObjectIdHelper.int_to_id(1),
   558              _ObjectIdHelper.int_to_id(3),
   559          ))
   560  
   561      self.assertTrue(self.mongo_source._range_is_not_splittable(0, 0))
   562      self.assertTrue(self.mongo_source._range_is_not_splittable(0, 1))
   563      self.assertFalse(self.mongo_source._range_is_not_splittable(0, 2))
   564  
   565      self.assertTrue(self.mongo_source._range_is_not_splittable("AAA", "AAA"))
   566      self.assertFalse(
   567          self.mongo_source._range_is_not_splittable("AAA", "AAA\x00"))
   568      self.assertFalse(self.mongo_source._range_is_not_splittable("AAA", "AAB"))
   569  
   570  
   571  @parameterized_class(('bucket_auto', ), [(False, ), (True, )])
   572  class ReadFromMongoDBTest(unittest.TestCase):
   573    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   574    def test_read_from_mongodb(self, mock_client):
   575      documents = [{
   576          '_id': objectid.ObjectId(), 'x': i, 'selected': 1, 'unselected': 2
   577      } for i in range(3)]
   578      mock_client.return_value = _MockMongoClient(documents)
   579  
   580      projection = ['x', 'selected']
   581      projected_documents = [{
   582          k: v
   583          for k, v in e.items() if k in projection or k == '_id'
   584      } for e in documents]
   585  
   586      with TestPipeline() as p:
   587        docs = p | 'ReadFromMongoDB' >> ReadFromMongoDB(
   588            uri='mongodb://test',
   589            db='db',
   590            coll='collection',
   591            projection=projection,
   592            bucket_auto=self.bucket_auto)
   593        assert_that(docs, equal_to(projected_documents))
   594  
   595  
   596  class GenerateObjectIdFnTest(unittest.TestCase):
   597    def test_process(self):
   598      with TestPipeline() as p:
   599        output = (
   600            p | "Create" >> beam.Create([{
   601                'x': 1
   602            }, {
   603                'x': 2, '_id': 123
   604            }])
   605            | "Generate ID" >> beam.ParDo(_GenerateObjectIdFn())
   606            | "Check" >> beam.Map(lambda x: '_id' in x))
   607        assert_that(output, equal_to([True] * 2))
   608  
   609  
   610  class WriteMongoFnTest(unittest.TestCase):
   611    @mock.patch('apache_beam.io.mongodbio._MongoSink')
   612    def test_process(self, mock_sink):
   613      docs = [{'x': 1}, {'x': 2}, {'x': 3}]
   614      with TestPipeline() as p:
   615        _ = (
   616            p | "Create" >> beam.Create(docs)
   617            | "Write" >> beam.ParDo(_WriteMongoFn(batch_size=2)))
   618        p.run()
   619  
   620        self.assertEqual(
   621            2, mock_sink.return_value.__enter__.return_value.write.call_count)
   622  
   623    def test_display_data(self):
   624      data = _WriteMongoFn(batch_size=10).display_data()
   625      self.assertEqual(10, data['batch_size'])
   626  
   627  
   628  class MongoSinkTest(unittest.TestCase):
   629    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   630    def test_write(self, mock_client):
   631      docs = [{'x': 1}, {'x': 2}, {'x': 3}]
   632      _MongoSink(uri='test', db='test', coll='test').write(docs)
   633      self.assertTrue(
   634          mock_client.return_value.__getitem__.return_value.__getitem__.
   635          return_value.bulk_write.called)
   636  
   637  
   638  class WriteToMongoDBTest(unittest.TestCase):
   639    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   640    def test_write_to_mongodb_with_existing_id(self, mock_client):
   641      _id = objectid.ObjectId()
   642      docs = [{'x': 1, '_id': _id}]
   643      expected_update = [
   644          ReplaceOne({'_id': _id}, {
   645              'x': 1, '_id': _id
   646          }, True, None)
   647      ]
   648      with TestPipeline() as p:
   649        _ = (
   650            p | "Create" >> beam.Create(docs)
   651            | "Write" >> WriteToMongoDB(db='test', coll='test'))
   652        p.run()
   653        mock_client.return_value.__getitem__.return_value.__getitem__. \
   654          return_value.bulk_write.assert_called_with(expected_update)
   655  
   656    @mock.patch('apache_beam.io.mongodbio.MongoClient')
   657    def test_write_to_mongodb_with_generated_id(self, mock_client):
   658      docs = [{'x': 1}]
   659      expected_update = [
   660          ReplaceOne({'_id': mock.ANY}, {
   661              'x': 1, '_id': mock.ANY
   662          }, True, None)
   663      ]
   664      with TestPipeline() as p:
   665        _ = (
   666            p | "Create" >> beam.Create(docs)
   667            | "Write" >> WriteToMongoDB(db='test', coll='test'))
   668        p.run()
   669        mock_client.return_value.__getitem__.return_value.__getitem__. \
   670          return_value.bulk_write.assert_called_with(expected_update)
   671  
   672  
   673  class ObjectIdHelperTest(TestCase):
   674    def test_conversion(self):
   675      test_cases = [
   676          (objectid.ObjectId('000000000000000000000000'), 0),
   677          (objectid.ObjectId('000000000000000100000000'), 2**32),
   678          (objectid.ObjectId('0000000000000000ffffffff'), 2**32 - 1),
   679          (objectid.ObjectId('000000010000000000000000'), 2**64),
   680          (objectid.ObjectId('00000000ffffffffffffffff'), 2**64 - 1),
   681          (objectid.ObjectId('ffffffffffffffffffffffff'), 2**96 - 1),
   682      ]
   683      for (_id, number) in test_cases:
   684        self.assertEqual(_id, _ObjectIdHelper.int_to_id(number))
   685        self.assertEqual(number, _ObjectIdHelper.id_to_int(_id))
   686  
   687      # random tests
   688      for _ in range(100):
   689        _id = objectid.ObjectId()
   690        number = int(_id.binary.hex(), 16)
   691        self.assertEqual(_id, _ObjectIdHelper.int_to_id(number))
   692        self.assertEqual(number, _ObjectIdHelper.id_to_int(_id))
   693  
   694    def test_increment_id(self):
   695      test_cases = [
   696          (
   697              objectid.ObjectId("000000000000000100000000"),
   698              objectid.ObjectId("0000000000000000ffffffff"),
   699          ),
   700          (
   701              objectid.ObjectId("000000010000000000000000"),
   702              objectid.ObjectId("00000000ffffffffffffffff"),
   703          ),
   704      ]
   705      for first, second in test_cases:
   706        self.assertEqual(second, _ObjectIdHelper.increment_id(first, -1))
   707        self.assertEqual(first, _ObjectIdHelper.increment_id(second, 1))
   708  
   709      for _ in range(100):
   710        _id = objectid.ObjectId()
   711        self.assertLess(_id, _ObjectIdHelper.increment_id(_id, 1))
   712        self.assertGreater(_id, _ObjectIdHelper.increment_id(_id, -1))
   713  
   714  
   715  class ObjectRangeTrackerTest(TestCase):
   716    def test_fraction_position_conversion(self):
   717      start_int = 0
   718      stop_int = 2**96 - 1
   719      start = _ObjectIdHelper.int_to_id(start_int)
   720      stop = _ObjectIdHelper.int_to_id(stop_int)
   721      test_cases = ([start_int, stop_int, 2**32, 2**32 - 1, 2**64, 2**64 - 1] +
   722                    [random.randint(start_int, stop_int) for _ in range(100)])
   723      tracker = _ObjectIdRangeTracker()
   724      for pos in test_cases:
   725        _id = _ObjectIdHelper.int_to_id(pos - start_int)
   726        desired_fraction = (pos - start_int) / (stop_int - start_int)
   727        self.assertAlmostEqual(
   728            tracker.position_to_fraction(_id, start, stop),
   729            desired_fraction,
   730            places=20)
   731  
   732        convert_id = tracker.fraction_to_position(
   733            (pos - start_int) / (stop_int - start_int), start, stop)
   734        # due to precision loss, the convert fraction is only gonna be close to
   735        # original fraction.
   736        convert_fraction = tracker.position_to_fraction(convert_id, start, stop)
   737  
   738        self.assertGreater(convert_id, start)
   739        self.assertLess(convert_id, stop)
   740        self.assertAlmostEqual(convert_fraction, desired_fraction, places=20)
   741  
   742  
if __name__ == '__main__':
  # When run directly: raise the root logger to INFO, then run every test.
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()