github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/mongodbio.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

    18  """This module implements IO classes to read and write data on MongoDB.
    19  
    20  
    21  Read from MongoDB
    22  -----------------
    23  :class:`ReadFromMongoDB` is a ``PTransform`` that reads from a configured
    24  MongoDB source and returns a ``PCollection`` of dict representing MongoDB
    25  documents.
    26  To configure MongoDB source, the URI to connect to MongoDB server, database
    27  name, collection name needs to be provided.
    28  
    29  Example usage::
    30  
    31    pipeline | ReadFromMongoDB(uri='mongodb://localhost:27017',
    32                               db='testdb',
    33                               coll='input')
    34  
    35  To read from MongoDB Atlas, use ``bucket_auto`` option to enable
    36  ``@bucketAuto`` MongoDB aggregation instead of ``splitVector``
    37  command which is a high-privilege function that cannot be assigned
    38  to any user in Atlas.
    39  
    40  Example usage::
    41  
    42    pipeline | ReadFromMongoDB(uri='mongodb+srv://user:pwd@cluster0.mongodb.net',
    43                               db='testdb',
    44                               coll='input',
    45                               bucket_auto=True)
    46  
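To filter documents and select fields on the server, ``filter`` and
``projection`` may be passed as well (the field name and value below are
hypothetical, for illustration only)::

  pipeline | ReadFromMongoDB(uri='mongodb://localhost:27017',
                             db='testdb',
                             coll='input',
                             filter={'status': 'active'},
                             projection=['status'])
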

Write to MongoDB
----------------
:class:`WriteToMongoDB` is a ``PTransform`` that writes MongoDB documents to
the configured sink. The write is conducted through a MongoDB bulk_write of
``ReplaceOne`` operations. If a document's _id field already exists in the
MongoDB collection, the document is overwritten; otherwise a new document is
inserted.

Example usage::

  pipeline | WriteToMongoDB(uri='mongodb://localhost:27017',
                            db='testdb',
                            coll='output',
                            batch_size=10)


No backward compatibility guarantees. Everything in this module is experimental.
"""

# pytype: skip-file

import itertools
import json
import logging
import math
import struct
from typing import Union

import apache_beam as beam
from apache_beam.io import iobase
from apache_beam.io.range_trackers import LexicographicKeyRangeTracker
from apache_beam.io.range_trackers import OffsetRangeTracker
from apache_beam.io.range_trackers import OrderedPositionRangeTracker
from apache_beam.transforms import DoFn
from apache_beam.transforms import PTransform
from apache_beam.transforms import Reshuffle

_LOGGER = logging.getLogger(__name__)

try:
  # pymongo bundles its own bson, which is not compatible with the standalone
  # bson package (https://github.com/py-bson/bson/issues/82). Try to import
  # objectid; if it fails because the standalone bson package is installed,
  # MongoDB IO will not work, but at least the rest of the SDK will.
  from bson import json_util
  from bson import objectid
  from bson.objectid import ObjectId

  # pymongo also internally depends on bson.
  from pymongo import ASCENDING
  from pymongo import DESCENDING
  from pymongo import MongoClient
  from pymongo import ReplaceOne
except ImportError:
  objectid = None
  json_util = None
  ObjectId = None
  ASCENDING = 1
  DESCENDING = -1
  MongoClient = None
  ReplaceOne = None
  _LOGGER.warning("Could not find a compatible bson package.")

__all__ = ["ReadFromMongoDB", "WriteToMongoDB"]


class ReadFromMongoDB(PTransform):
  """A ``PTransform`` to read MongoDB documents into a ``PCollection``."""
  def __init__(
      self,
      uri="mongodb://localhost:27017",
      db=None,
      coll=None,
      filter=None,
      projection=None,
      extra_client_params=None,
      bucket_auto=False,
  ):
    """Initialize a :class:`ReadFromMongoDB`.

    Args:
      uri (str): The MongoDB connection string following the URI format.
      db (str): The MongoDB database name.
      coll (str): The MongoDB collection name.
      filter: A `bson.SON
        <https://api.mongodb.com/python/current/api/bson/son.html>`_ object
        specifying elements which must be present for a document to be
        included in the result set.
      projection: A list of field names that should be returned in the result
        set, or a dict specifying the fields to include or exclude.
      extra_client_params(dict): Optional `MongoClient
        <https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_
        parameters.
      bucket_auto (bool): If :data:`True`, use MongoDB `$bucketAuto`
        aggregation to split the collection into bundles instead of the
        `splitVector` command, which does not work with MongoDB Atlas.
        If :data:`False` (the default), use the `splitVector` command for
        bundling.

    Returns:
      :class:`~apache_beam.transforms.ptransform.PTransform`
    """
    if extra_client_params is None:
      extra_client_params = {}
    if not isinstance(db, str):
      raise ValueError("ReadFromMongoDB db param must be specified as a string")
    if not isinstance(coll, str):
      raise ValueError(
          "ReadFromMongoDB coll param must be specified as a string")
    self._mongo_source = _BoundedMongoSource(
        uri=uri,
        db=db,
        coll=coll,
        filter=filter,
        projection=projection,
        extra_client_params=extra_client_params,
        bucket_auto=bucket_auto,
    )

  def expand(self, pcoll):
    return pcoll | iobase.Read(self._mongo_source)


class _ObjectIdRangeTracker(OrderedPositionRangeTracker):
  """RangeTracker for tracking MongoDB _id fields of bson ObjectId type."""
  def position_to_fraction(
      self,
      pos: ObjectId,
      start: ObjectId,
      end: ObjectId,
  ):
    """Returns the fraction of keys in the range [start, end) that
    are less than the given key.
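
    Example (illustrative): if ``start`` and ``end`` map to the integers 0
    and 16 via ``_ObjectIdHelper.id_to_int`` and ``pos`` maps to 4, the
    result is ``4 / 16 = 0.25``.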
    """
    pos_number = _ObjectIdHelper.id_to_int(pos)
    start_number = _ObjectIdHelper.id_to_int(start)
    end_number = _ObjectIdHelper.id_to_int(end)
    return (pos_number - start_number) / (end_number - start_number)

  def fraction_to_position(
      self,
      fraction: float,
      start: ObjectId,
      end: ObjectId,
  ):
    """Converts a fraction between 0 and 1
    to a position between start and end.
    """
    start_number = _ObjectIdHelper.id_to_int(start)
    end_number = _ObjectIdHelper.id_to_int(end)
    total = end_number - start_number
    pos = int(total * fraction + start_number)
    # make sure the split position is larger than the start position and
    # smaller than the end position.
    if pos <= start_number:
      return _ObjectIdHelper.increment_id(start, 1)

    if pos >= end_number:
      return _ObjectIdHelper.increment_id(end, -1)

    return _ObjectIdHelper.int_to_id(pos)


class _BoundedMongoSource(iobase.BoundedSource):
  """A MongoDB source that reads a finite amount of input records.

  This class defines the following operations, which can be used to read from
  a MongoDB source efficiently.

  * Size estimation - method ``estimate_size()`` may return an accurate
    estimation in bytes for the size of the source.
  * Splitting into bundles of a given size - method ``split()`` can be used to
    split the source into a set of sub-sources (bundles) based on a desired
    bundle size.
  * Getting a RangeTracker - method ``get_range_tracker()`` should return a
    ``RangeTracker`` object for a given position range for the position type
    of the records returned by the source.
  * Reading the data - method ``read()`` can be used to read data from the
    source while respecting the boundaries defined by a given
    ``RangeTracker``.

  A runner reads the source in two steps.

  (1) Method ``get_range_tracker()`` will be invoked with start and end
      positions to obtain a ``RangeTracker`` for the range of positions the
      runner intends to read. The source must define a default initial start
      and end position range. These positions must be used if the start
      and/or end positions passed to the method ``get_range_tracker()`` are
      ``None``.
  (2) Method ``read()`` will be invoked with the ``RangeTracker`` obtained in
      the previous step.

  **Mutability**

  A ``_BoundedMongoSource`` object should not be mutated while
  its methods (for example, ``read()``) are being invoked by a runner. Runner
  implementations may invoke methods of ``_BoundedMongoSource`` objects
  through multi-threaded and/or reentrant execution modes.
  """
  def __init__(
      self,
      uri=None,
      db=None,
      coll=None,
      filter=None,
      projection=None,
      extra_client_params=None,
      bucket_auto=False,
  ):
    if extra_client_params is None:
      extra_client_params = {}
    if filter is None:
      filter = {}
    self.uri = uri
    self.db = db
    self.coll = coll
    self.filter = filter
    self.projection = projection
    self.spec = extra_client_params
    self.bucket_auto = bucket_auto

  def estimate_size(self):
    with MongoClient(self.uri, **self.spec) as client:
      return client[self.db].command("collstats", self.coll).get("size")

  def _estimate_average_document_size(self):
    with MongoClient(self.uri, **self.spec) as client:
      return client[self.db].command("collstats", self.coll).get("avgObjSize")

  def split(
      self,
      desired_bundle_size: int,
      start_position: Union[int, str, bytes, ObjectId] = None,
      stop_position: Union[int, str, bytes, ObjectId] = None,
  ):
    """Splits the source into a set of bundles.

    Bundles should be approximately of size ``desired_bundle_size`` bytes.

    Args:
      desired_bundle_size: the desired size (in bytes) of the bundles returned.
      start_position: if specified, the given position must be used as the
                      starting position of the first bundle.
      stop_position: if specified, the given position must be used as the
                     ending position of the last bundle.
    Returns:
      an iterator of objects of type ``SourceBundle`` that gives information
      about the generated bundles.
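
    Example (illustrative): for a collection of roughly 100 MB and a
    ``desired_bundle_size`` of ``32 * 1024 * 1024`` bytes (32 MB), the
    source is split into roughly ``ceil(100 / 32) = 4`` bundles.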
   294      """
   295  
   296      desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024
   297  
   298      # for desired bundle size, if desired chunk size smaller than 1mb, use
   299      # MongoDB default split size of 1mb.
   300      desired_bundle_size_in_mb = max(desired_bundle_size_in_mb, 1)
   301  
   302      is_initial_split = start_position is None and stop_position is None
   303      start_position, stop_position = self._replace_none_positions(
   304        start_position, stop_position
   305      )
   306  
   307      if self.bucket_auto:
   308        # Use $bucketAuto for bundling
   309        split_keys = []
   310        weights = []
   311        for bucket in self._get_auto_buckets(
   312            desired_bundle_size_in_mb,
   313            start_position,
   314            stop_position,
   315            is_initial_split,
   316        ):
   317          split_keys.append({"_id": bucket["_id"]["max"]})
   318          weights.append(bucket["count"])
   319      else:
   320        # Use splitVector for bundling
   321        split_keys = self._get_split_keys(
   322            desired_bundle_size_in_mb, start_position, stop_position)
   323        weights = itertools.cycle((desired_bundle_size_in_mb, ))
   324  
   325      bundle_start = start_position
   326      for split_key_id, weight in zip(split_keys, weights):
   327        if bundle_start >= stop_position:
   328          break
   329        bundle_end = min(stop_position, split_key_id["_id"])
   330        yield iobase.SourceBundle(
   331            weight=weight,
   332            source=self,
   333            start_position=bundle_start,
   334            stop_position=bundle_end,
   335        )
   336        bundle_start = bundle_end
   337      # add range of last split_key to stop_position
   338      if bundle_start < stop_position:
   339        # bucket_auto mode can come here if not split due to single document
   340        weight = 1 if self.bucket_auto else desired_bundle_size_in_mb
   341        yield iobase.SourceBundle(
   342            weight=weight,
   343            source=self,
   344            start_position=bundle_start,
   345            stop_position=stop_position,
   346        )
   347  
  def get_range_tracker(
      self,
      start_position: Union[int, str, ObjectId] = None,
      stop_position: Union[int, str, ObjectId] = None,
  ) -> Union[
      _ObjectIdRangeTracker, OffsetRangeTracker, LexicographicKeyRangeTracker]:
    """Returns a RangeTracker for a given position range depending on type.

    Args:
      start_position: starting position of the range. If 'None', the default
                      start position of the source must be used.
      stop_position:  ending position of the range. If 'None', the default
                      stop position of the source must be used.
    Returns:
      an ``_ObjectIdRangeTracker``, ``OffsetRangeTracker`` or
      ``LexicographicKeyRangeTracker`` depending on the given position range.
    """
    start_position, stop_position = self._replace_none_positions(
        start_position, stop_position)

    if isinstance(start_position, ObjectId):
      return _ObjectIdRangeTracker(start_position, stop_position)

    if isinstance(start_position, int):
      return OffsetRangeTracker(start_position, stop_position)

    if isinstance(start_position, str):
      return LexicographicKeyRangeTracker(start_position, stop_position)

    raise NotImplementedError(
        f"RangeTracker for {type(start_position)} not implemented!")

  def read(self, range_tracker):
    """Returns an iterator that reads data from the source.

    The returned set of data must respect the boundaries defined by the given
    ``RangeTracker`` object. For example:

      * Returned set of data must be for the range
        ``[range_tracker.start_position, range_tracker.stop_position)``. Note
        that a source may decide to return records that start after
        ``range_tracker.stop_position``. See documentation in class
        ``RangeTracker`` for more details. Also, note that the framework
        might invoke ``range_tracker.try_split()`` to perform dynamic split
        operations. range_tracker.stop_position may be updated
        dynamically due to successful dynamic split operations.
      * Method ``range_tracker.try_split()`` must be invoked for every record
        that starts at a split point.
      * Method ``range_tracker.record_current_position()`` may be invoked for
        records that do not start at split points.

    Args:
      range_tracker: a ``RangeTracker`` whose boundaries must be respected
                     when reading data from the source. A runner that reads
                     this source must pass a ``RangeTracker`` object that is
                     not ``None``.
    Returns:
      an iterator of data read by the source.
    """
    with MongoClient(self.uri, **self.spec) as client:
      all_filters = self._merge_id_filter(
          range_tracker.start_position(), range_tracker.stop_position())
      docs_cursor = (
          client[self.db][self.coll].find(
              filter=all_filters,
              projection=self.projection).sort([("_id", ASCENDING)]))
      for doc in docs_cursor:
        if not range_tracker.try_claim(doc["_id"]):
          return
        yield doc

  def display_data(self):
    """Returns the display data associated to a pipeline component."""
    res = super().display_data()
    res["database"] = self.db
    res["collection"] = self.coll
    res["filter"] = json.dumps(self.filter, default=json_util.default)
    res["projection"] = str(self.projection)
    res["bucket_auto"] = self.bucket_auto
    return res

  @staticmethod
  def _range_is_not_splittable(
      start_pos: Union[int, str, ObjectId],
      end_pos: Union[int, str, ObjectId],
  ):
    """Returns `True` if splitting the range doesn't make sense
    (a single document is not splittable),
    `False` otherwise.
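
    Example (illustrative): for integer ``_id`` values, ``start_pos=5`` and
    ``end_pos=6`` cover at most a single document, so this returns ``True``.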
   438      """
   439      return ((
   440          isinstance(start_pos, ObjectId) and
   441          start_pos >= _ObjectIdHelper.increment_id(end_pos, -1)) or
   442              (isinstance(start_pos, int) and start_pos >= end_pos - 1) or
   443              (isinstance(start_pos, str) and start_pos >= end_pos))
   444  
  def _get_split_keys(
      self,
      desired_chunk_size_in_mb: int,
      start_pos: Union[int, str, ObjectId],
      end_pos: Union[int, str, ObjectId],
  ):
    """Calls the MongoDB `splitVector` command
    to get document ids at split positions.
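
    The command returns a list of boundary documents of the form
    (illustrative)::

      [{'_id': ObjectId('...')}, {'_id': ObjectId('...')}, ...]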
   453      """
   454      # single document not splittable
   455      if self._range_is_not_splittable(start_pos, end_pos):
   456        return []
   457  
   458      with MongoClient(self.uri, **self.spec) as client:
   459        name_space = "%s.%s" % (self.db, self.coll)
   460        return client[self.db].command(
   461          "splitVector",
   462          name_space,
   463          keyPattern={"_id": 1},  # Ascending index
   464          min={"_id": start_pos},
   465          max={"_id": end_pos},
   466          maxChunkSize=desired_chunk_size_in_mb,
   467        )["splitKeys"]
   468  
  def _get_auto_buckets(
      self,
      desired_chunk_size_in_mb: int,
      start_pos: Union[int, str, ObjectId],
      end_pos: Union[int, str, ObjectId],
      is_initial_split: bool,
  ) -> list:
    """Use the MongoDB `$bucketAuto` aggregation to split a collection into
    bundles instead of the `splitVector` command, which does not work with
    MongoDB Atlas.
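
    Each bucket produced by ``$bucketAuto`` groups a contiguous range of
    ``_id`` values and has (illustratively) the shape::

      {'_id': {'min': <first _id in bucket>, 'max': <last _id in bucket>},
       'count': <number of documents in bucket>}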
   478      """
   479      # single document not splittable
   480      if self._range_is_not_splittable(start_pos, end_pos):
   481        return []
   482  
   483      if is_initial_split and not self.filter:
   484        # total collection size in MB
   485        size_in_mb = self.estimate_size() / float(1 << 20)
   486      else:
   487        # size of documents within start/end id range and possibly filtered
   488        documents_count = self._count_id_range(start_pos, end_pos)
   489        avg_document_size = self._estimate_average_document_size()
   490        size_in_mb = documents_count * avg_document_size / float(1 << 20)
   491  
   492      if size_in_mb == 0:
   493        # no documents not splittable (maybe a result of filtering)
   494        return []
   495  
   496      bucket_count = math.ceil(size_in_mb / desired_chunk_size_in_mb)
   497      with beam.io.mongodbio.MongoClient(self.uri, **self.spec) as client:
   498        pipeline = [
   499            {
   500                # filter by positions and by the custom filter if any
   501                "$match": self._merge_id_filter(start_pos, end_pos)
   502            },
   503            {
   504                "$bucketAuto": {
   505                    "groupBy": "$_id", "buckets": bucket_count
   506                }
   507            },
   508        ]
   509        buckets = list(
   510            # Use `allowDiskUse` option to avoid aggregation limit of 100 Mb RAM
   511            client[self.db][self.coll].aggregate(pipeline, allowDiskUse=True))
   512        if buckets:
   513          buckets[-1]["_id"]["max"] = end_pos
   514  
   515        return buckets
   516  
  def _merge_id_filter(
      self,
      start_position: Union[int, str, bytes, ObjectId],
      stop_position: Union[int, str, bytes, ObjectId] = None,
  ) -> dict:
    """Merge the default filter (if any) with a refined _id field range
    from the range_tracker.
    $gte specifies the start position (inclusive)
    and $lt specifies the end position (exclusive);
    see more at
    https://docs.mongodb.com/manual/reference/operator/query/gte/ and
    https://docs.mongodb.com/manual/reference/operator/query/lt/
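
    For example (illustrative, with a hypothetical custom filter), given
    ``self.filter = {'type': 'user'}``, the merged filter looks like::

      {'$and': [{'type': 'user'},
                {'_id': {'$gte': start_position, '$lt': stop_position}}]}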
   529      """
   530  
   531      if stop_position is None:
   532        id_filter = {"_id": {"$gte": start_position}}
   533      else:
   534        id_filter = {"_id": {"$gte": start_position, "$lt": stop_position}}
   535  
   536      if self.filter:
   537        all_filters = {
   538            # see more at
   539            # https://docs.mongodb.com/manual/reference/operator/query/and/
   540            "$and": [self.filter.copy(), id_filter]
   541        }
   542      else:
   543        all_filters = id_filter
   544  
   545      return all_filters
   546  
  def _get_head_document_id(self, sort_order):
    with MongoClient(self.uri, **self.spec) as client:
      cursor = (
          client[self.db][self.coll].find(filter={}, projection=[]).sort([
              ("_id", sort_order)
          ]).limit(1))
      try:
        return cursor[0]["_id"]

      except IndexError:
        raise ValueError("Empty MongoDB collection")

  def _replace_none_positions(self, start_position, stop_position):

    if start_position is None:
      start_position = self._get_head_document_id(ASCENDING)
    if stop_position is None:
      last_doc_id = self._get_head_document_id(DESCENDING)
      # increment the last doc id's binary value by 1 to make sure the last
      # document is not excluded
      if isinstance(last_doc_id, ObjectId):
        stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1)
      elif isinstance(last_doc_id, int):
        stop_position = last_doc_id + 1
      elif isinstance(last_doc_id, str):
        # appending '\x00' yields the smallest string strictly greater than
        # last_doc_id, so the last document is included
        stop_position = last_doc_id + '\x00'

    return start_position, stop_position

  def _count_id_range(self, start_position, stop_position):
    """Number of documents between start_position (inclusive)
    and stop_position (exclusive), respecting the custom filter, if any.
    """
    with MongoClient(self.uri, **self.spec) as client:
      return client[self.db][self.coll].count_documents(
          filter=self._merge_id_filter(start_position, stop_position))


class _ObjectIdHelper:
  """A utility class to manipulate bson object ids."""
  @classmethod
  def id_to_int(cls, _id: Union[int, ObjectId]) -> int:
    """
    Args:
      _id: ObjectId required for each MongoDB document _id field.

    Returns: The integer value converted from the ObjectId's 12-byte binary
      value.
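
    Example round trip with ``int_to_id`` (illustrative values)::

      >>> _id = ObjectId('000000000000000000000001')
      >>> _ObjectIdHelper.id_to_int(_id)
      1
      >>> _ObjectIdHelper.int_to_id(1) == _id
      True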
   594      """
   595      if isinstance(_id, int):
   596        return _id
   597  
   598      # converts object id binary to integer
   599      # id object is bytes type with size of 12
   600      ints = struct.unpack(">III", _id.binary)
   601      return (ints[0] << 64) + (ints[1] << 32) + ints[2]
   602  
   603    @classmethod
   604    def int_to_id(cls, number):
   605      """
   606      Args:
   607        number(int): The integer value to be used to convert to ObjectId.
   608  
   609      Returns: The ObjectId that has the 12 bytes binary converted from the
   610        integer value.
   611      """
   612      # converts integer value to object id. Int value should be less than
   613      # (2 ^ 96) so it can be convert to 12 bytes required by object id.
   614      if number < 0 or number >= (1 << 96):
   615        raise ValueError("number value must be within [0, %s)" % (1 << 96))
   616      ints = [
   617          (number & 0xFFFFFFFF0000000000000000) >> 64,
   618          (number & 0x00000000FFFFFFFF00000000) >> 32,
   619          number & 0x0000000000000000FFFFFFFF,
   620      ]
   621  
   622      number_bytes = struct.pack(">III", *ints)
   623      return ObjectId(number_bytes)
   624  
  @classmethod
  def increment_id(
      cls,
      _id: ObjectId,
      inc: int,
  ) -> ObjectId:
    """
    Increment an object id's binary value by the inc value and return a new
    object id.

    Args:
      _id: The `_id` to change.
      inc(int): The incremental int value to be added to `_id`.

    Returns:
        `_id` incremented by the `inc` value
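
    Example (illustrative)::

      >>> _ObjectIdHelper.increment_id(ObjectId('000000000000000000000001'), 1)
      ObjectId('000000000000000000000002')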
   640      """
   641      id_number = _ObjectIdHelper.id_to_int(_id)
   642      new_number = id_number + inc
   643      if new_number < 0 or new_number >= (1 << 96):
   644        raise ValueError(
   645            "invalid incremental, inc value must be within ["
   646            "%s, %s)" % (0 - id_number, 1 << 96 - id_number))
   647      return _ObjectIdHelper.int_to_id(new_number)
   648  
   649  
class WriteToMongoDB(PTransform):
  """WriteToMongoDB is a ``PTransform`` that writes a ``PCollection`` of
  MongoDB documents to the configured MongoDB server.

  In order to make document writes idempotent, so that bundles are retryable
  without creating duplicates, the PTransform adds two transforms before the
  final write stage:
  a ``GenerateId`` transform and a ``Reshuffle`` transform.::

                  -----------------------------------------------
    Pipeline -->  |GenerateId --> Reshuffle --> WriteToMongoSink|
                  -----------------------------------------------
                                  (WriteToMongoDB)

  The ``GenerateId`` transform adds a random and unique ``_id`` field to the
  documents if they don't already have one; it uses the same format as the
  MongoDB default. The ``Reshuffle`` transform makes sure that no fusion
  happens between ``GenerateId`` and the final write stage transform, so that
  the set of documents and their unique IDs are not regenerated if the final
  write step is retried due to a failure. This prevents duplicate writes of
  the same document with different unique IDs.

  """
  def __init__(
      self,
      uri="mongodb://localhost:27017",
      db=None,
      coll=None,
      batch_size=100,
      extra_client_params=None,
  ):
    """

    Args:
      uri (str): The MongoDB connection string following the URI format.
      db (str): The MongoDB database name.
      coll (str): The MongoDB collection name.
      batch_size(int): Number of documents per bulk_write to MongoDB;
        defaults to 100.
      extra_client_params(dict): Optional `MongoClient
       <https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_
       parameters as keyword arguments.

    Returns:
      :class:`~apache_beam.transforms.ptransform.PTransform`

    """
    if extra_client_params is None:
      extra_client_params = {}
    if not isinstance(db, str):
      raise ValueError("WriteToMongoDB db param must be specified as a string")
    if not isinstance(coll, str):
      raise ValueError(
          "WriteToMongoDB coll param must be specified as a string")
    self._uri = uri
    self._db = db
    self._coll = coll
    self._batch_size = batch_size
    self._spec = extra_client_params

  def expand(self, pcoll):
    return (
        pcoll
        | beam.ParDo(_GenerateObjectIdFn())
        | Reshuffle()
        | beam.ParDo(
            _WriteMongoFn(
                self._uri, self._db, self._coll, self._batch_size, self._spec)))


class _GenerateObjectIdFn(DoFn):
  def process(self, element, *args, **kwargs):
    # if the _id field already exists we keep it as it is; otherwise the
    # ptransform generates a new _id field to achieve idempotent writes to
    # mongodb.
    if "_id" not in element:
      # objectid.ObjectId() generates a unique identifier that follows the
      # mongodb default format; if _id is not present in a document, the
      # mongodb server generates it with this same function upon write.
      # However, the uniqueness of generated ids may not be guaranteed if the
      # workload is distributed across too many processes. See more on the
      # ObjectId format at
      # https://docs.mongodb.com/manual/reference/bson-types/#objectid.
      element["_id"] = objectid.ObjectId()

    yield element


class _WriteMongoFn(DoFn):
  def __init__(
      self, uri=None, db=None, coll=None, batch_size=100, extra_params=None):
    if extra_params is None:
      extra_params = {}
    self.uri = uri
    self.db = db
    self.coll = coll
    self.spec = extra_params
    self.batch_size = batch_size
    self.batch = []

  def finish_bundle(self):
    self._flush()

  def process(self, element, *args, **kwargs):
    self.batch.append(element)
    if len(self.batch) >= self.batch_size:
      self._flush()

  def _flush(self):
    if len(self.batch) == 0:
      return
    with _MongoSink(self.uri, self.db, self.coll, self.spec) as sink:
      sink.write(self.batch)
      self.batch = []

  def display_data(self):
    res = super().display_data()
    res["database"] = self.db
    res["collection"] = self.coll
    res["batch_size"] = self.batch_size
    return res


class _MongoSink:
  def __init__(self, uri=None, db=None, coll=None, extra_params=None):
    if extra_params is None:
      extra_params = {}
    self.uri = uri
    self.db = db
    self.coll = coll
    self.spec = extra_params
    self.client = None

  def write(self, documents):
    if self.client is None:
      self.client = MongoClient(host=self.uri, **self.spec)
    requests = []
    for doc in documents:
      # match the document based on the _id field; if not found in the
      # current collection, insert a new one, otherwise overwrite it.
      requests.append(
          ReplaceOne(
              filter={"_id": doc.get("_id", None)},
              replacement=doc,
              upsert=True))
    resp = self.client[self.db][self.coll].bulk_write(requests)
    _LOGGER.debug(
        "BulkWrite to MongoDB resulted in nModified:%d, nUpserted:%d, "
        "nMatched:%d, Errors:%s" % (
            resp.modified_count,
            resp.upserted_count,
            resp.matched_count,
            resp.bulk_api_result.get("writeErrors"),
        ))

  def __enter__(self):
    if self.client is None:
      self.client = MongoClient(host=self.uri, **self.spec)
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    if self.client is not None:
      self.client.close()