github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/mongodbio.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""This module implements IO classes to read and write data on MongoDB.


Read from MongoDB
-----------------
:class:`ReadFromMongoDB` is a ``PTransform`` that reads from a configured
MongoDB source and returns a ``PCollection`` of dicts representing MongoDB
documents.
To configure the MongoDB source, the URI of the MongoDB server, the database
name and the collection name need to be provided.

Example usage::

  pipeline | ReadFromMongoDB(uri='mongodb://localhost:27017',
                             db='testdb',
                             coll='input')

To read from MongoDB Atlas, use the ``bucket_auto`` option to enable the
``$bucketAuto`` MongoDB aggregation instead of the ``splitVector`` command,
which is a high-privilege function that cannot be assigned to any user in
Atlas.

Example usage::

  pipeline | ReadFromMongoDB(uri='mongodb+srv://user:pwd@cluster0.mongodb.net',
                             db='testdb',
                             coll='input',
                             bucket_auto=True)


Write to MongoDB
----------------
:class:`WriteToMongoDB` is a ``PTransform`` that writes MongoDB documents to
the configured sink. The write is conducted through a MongoDB bulk_write of
``ReplaceOne`` operations. If a document's ``_id`` field already exists in the
MongoDB collection, it results in an overwrite; otherwise, a new document is
inserted.

Example usage::

  pipeline | WriteToMongoDB(uri='mongodb://localhost:27017',
                            db='testdb',
                            coll='output',
                            batch_size=10)


No backward compatibility guarantees. Everything in this module is
experimental.
"""

# pytype: skip-file

import itertools
import json
import logging
import math
import struct
from typing import Union

import apache_beam as beam
from apache_beam.io import iobase
from apache_beam.io.range_trackers import LexicographicKeyRangeTracker
from apache_beam.io.range_trackers import OffsetRangeTracker
from apache_beam.io.range_trackers import OrderedPositionRangeTracker
from apache_beam.transforms import DoFn
from apache_beam.transforms import PTransform
from apache_beam.transforms import Reshuffle

_LOGGER = logging.getLogger(__name__)

try:
  # MongoDB has its own bundled bson, which is not compatible with the bson
  # package (https://github.com/py-bson/bson/issues/82). Try to import
  # objectid, and if it fails because the bson package is installed, MongoDB
  # IO will not work, but at least the rest of the SDK will.
  from bson import json_util
  from bson import objectid
  from bson.objectid import ObjectId

  # pymongo also internally depends on bson.
  from pymongo import ASCENDING
  from pymongo import DESCENDING
  from pymongo import MongoClient
  from pymongo import ReplaceOne
except ImportError:
  objectid = None
  json_util = None
  ObjectId = None
  ASCENDING = 1
  DESCENDING = -1
  MongoClient = None
  ReplaceOne = None
  _LOGGER.warning("Could not find a compatible bson package.")

__all__ = ["ReadFromMongoDB", "WriteToMongoDB"]


class ReadFromMongoDB(PTransform):
  """A ``PTransform`` to read MongoDB documents into a ``PCollection``."""
  def __init__(
      self,
      uri="mongodb://localhost:27017",
      db=None,
      coll=None,
      filter=None,
      projection=None,
      extra_client_params=None,
      bucket_auto=False,
  ):
    """Initialize a :class:`ReadFromMongoDB`

    Args:
      uri (str): The MongoDB connection string following the URI format.
      db (str): The MongoDB database name.
      coll (str): The MongoDB collection name.
      filter: A `bson.SON
        <https://api.mongodb.com/python/current/api/bson/son.html>`_ object
        specifying elements which must be present for a document to be
        included in the result set.
      projection: A list of field names that should be returned in the result
        set or a dict specifying the fields to include or exclude.
      extra_client_params (dict): Optional `MongoClient
        <https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_
        parameters.
      bucket_auto (bool): If :data:`True`, use the MongoDB `$bucketAuto`
        aggregation to split the collection into bundles instead of the
        `splitVector` command, which does not work with MongoDB Atlas.
        If :data:`False` (the default), use the `splitVector` command for
        bundling.

    Returns:
      :class:`~apache_beam.transforms.ptransform.PTransform`
    """
    if extra_client_params is None:
      extra_client_params = {}
    if not isinstance(db, str):
      raise ValueError(
          "ReadFromMongoDB db param must be specified as a string")
    if not isinstance(coll, str):
      raise ValueError(
          "ReadFromMongoDB coll param must be specified as a string")
    self._mongo_source = _BoundedMongoSource(
        uri=uri,
        db=db,
        coll=coll,
        filter=filter,
        projection=projection,
        extra_client_params=extra_client_params,
        bucket_auto=bucket_auto,
    )

  def expand(self, pcoll):
    return pcoll | iobase.Read(self._mongo_source)


class _ObjectIdRangeTracker(OrderedPositionRangeTracker):
  """RangeTracker for tracking MongoDB _id of bson ObjectId type."""
  def position_to_fraction(
      self,
      pos: ObjectId,
      start: ObjectId,
      end: ObjectId,
  ):
    """Returns the fraction of keys in the range [start, end) that
    are less than the given key.
    """
    pos_number = _ObjectIdHelper.id_to_int(pos)
    start_number = _ObjectIdHelper.id_to_int(start)
    end_number = _ObjectIdHelper.id_to_int(end)
    return (pos_number - start_number) / (end_number - start_number)

  def fraction_to_position(
      self,
      fraction: float,
      start: ObjectId,
      end: ObjectId,
  ):
    """Converts a fraction between 0 and 1
    to a position between start and end.
    """
    start_number = _ObjectIdHelper.id_to_int(start)
    end_number = _ObjectIdHelper.id_to_int(end)
    total = end_number - start_number
    pos = int(total * fraction + start_number)
    # Make sure the split position is larger than the start position and
    # smaller than the end position.
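    # Worked example (illustrative numbers only): with start_number=0,
    # end_number=1000 and fraction=0.25, pos = int(1000 * 0.25 + 0) = 250.
    # The checks below only adjust pos when it would not fall strictly
    # between the start and end positions.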
    if pos <= start_number:
      return _ObjectIdHelper.increment_id(start, 1)

    if pos >= end_number:
      return _ObjectIdHelper.increment_id(end, -1)

    return _ObjectIdHelper.int_to_id(pos)


class _BoundedMongoSource(iobase.BoundedSource):
  """A MongoDB source that reads a finite amount of input records.

  This class defines the following operations, which can be used to read the
  MongoDB source efficiently.

  * Size estimation - method ``estimate_size()`` may return an accurate
    estimation in bytes for the size of the source.
  * Splitting into bundles of a given size - method ``split()`` can be used to
    split the source into a set of sub-sources (bundles) based on a desired
    bundle size.
  * Getting a RangeTracker - method ``get_range_tracker()`` should return a
    ``RangeTracker`` object for a given position range for the position type
    of the records returned by the source.
  * Reading the data - method ``read()`` can be used to read data from the
    source while respecting the boundaries defined by a given
    ``RangeTracker``.

  A runner will read the source in two steps.

  (1) Method ``get_range_tracker()`` will be invoked with start and end
      positions to obtain a ``RangeTracker`` for the range of positions the
      runner intends to read. The source must define a default initial start
      and end position range. These positions must be used if the start
      and/or end positions passed to the method ``get_range_tracker()`` are
      ``None``.
  (2) Method ``read()`` will be invoked with the ``RangeTracker`` obtained in
      the previous step.

  **Mutability**

  A ``_BoundedMongoSource`` object should not be mutated while
  its methods (for example, ``read()``) are being invoked by a runner. Runner
  implementations may invoke methods of ``_BoundedMongoSource`` objects
  through multi-threaded and/or reentrant execution modes.
  """
  def __init__(
      self,
      uri=None,
      db=None,
      coll=None,
      filter=None,
      projection=None,
      extra_client_params=None,
      bucket_auto=False,
  ):
    if extra_client_params is None:
      extra_client_params = {}
    if filter is None:
      filter = {}
    self.uri = uri
    self.db = db
    self.coll = coll
    self.filter = filter
    self.projection = projection
    self.spec = extra_client_params
    self.bucket_auto = bucket_auto

  def estimate_size(self):
    with MongoClient(self.uri, **self.spec) as client:
      return client[self.db].command("collstats", self.coll).get("size")

  def _estimate_average_document_size(self):
    with MongoClient(self.uri, **self.spec) as client:
      return client[self.db].command("collstats", self.coll).get("avgObjSize")

  def split(
      self,
      desired_bundle_size: int,
      start_position: Union[int, str, bytes, ObjectId] = None,
      stop_position: Union[int, str, bytes, ObjectId] = None,
  ):
    """Splits the source into a set of bundles.

    Bundles should be approximately of size ``desired_bundle_size`` bytes.

    Args:
      desired_bundle_size: the desired size (in bytes) of the bundles returned.
      start_position: if specified, the given position must be used as the
        starting position of the first bundle.
      stop_position: if specified, the given position must be used as the
        ending position of the last bundle.
    Returns:
      an iterator of objects of type ``SourceBundle`` that gives information
      about the generated bundles.
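
    Example (an illustrative sketch only, where ``source`` is a
    ``_BoundedMongoSource`` instance and 64 MB is an arbitrary bundle size)::

      for bundle in source.split(desired_bundle_size=64 << 20):
        tracker = source.get_range_tracker(
            bundle.start_position, bundle.stop_position)
        for doc in source.read(tracker):
          ...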
    """

    desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024

    # If the desired bundle (chunk) size is smaller than 1 MB, use the MongoDB
    # default split size of 1 MB.
    desired_bundle_size_in_mb = max(desired_bundle_size_in_mb, 1)

    is_initial_split = start_position is None and stop_position is None
    start_position, stop_position = self._replace_none_positions(
        start_position, stop_position
    )

    if self.bucket_auto:
      # Use $bucketAuto for bundling
      split_keys = []
      weights = []
      for bucket in self._get_auto_buckets(
          desired_bundle_size_in_mb,
          start_position,
          stop_position,
          is_initial_split,
      ):
        split_keys.append({"_id": bucket["_id"]["max"]})
        weights.append(bucket["count"])
    else:
      # Use splitVector for bundling
      split_keys = self._get_split_keys(
          desired_bundle_size_in_mb, start_position, stop_position)
      weights = itertools.cycle((desired_bundle_size_in_mb, ))

    bundle_start = start_position
    for split_key_id, weight in zip(split_keys, weights):
      if bundle_start >= stop_position:
        break
      bundle_end = min(stop_position, split_key_id["_id"])
      yield iobase.SourceBundle(
          weight=weight,
          source=self,
          start_position=bundle_start,
          stop_position=bundle_end,
      )
      bundle_start = bundle_end
    # Add the range from the last split key up to stop_position.
    if bundle_start < stop_position:
      # bucket_auto mode can come here if there was no split due to a single
      # document.
      weight = 1 if self.bucket_auto else desired_bundle_size_in_mb
      yield iobase.SourceBundle(
          weight=weight,
          source=self,
          start_position=bundle_start,
          stop_position=stop_position,
      )

  def get_range_tracker(
      self,
      start_position: Union[int, str, ObjectId] = None,
      stop_position: Union[int, str, ObjectId] = None,
  ) -> Union[
      _ObjectIdRangeTracker, OffsetRangeTracker, LexicographicKeyRangeTracker]:
    """Returns a RangeTracker for a given position range depending on type.

    Args:
      start_position: starting position of the range. If 'None', the default
        start position of the source must be used.
      stop_position: ending position of the range. If 'None', the default
        stop position of the source must be used.
    Returns:
      an ``_ObjectIdRangeTracker``, ``OffsetRangeTracker``
      or ``LexicographicKeyRangeTracker`` depending on the given position
      range.
    """
    start_position, stop_position = self._replace_none_positions(
        start_position, stop_position
    )

    if isinstance(start_position, ObjectId):
      return _ObjectIdRangeTracker(start_position, stop_position)

    if isinstance(start_position, int):
      return OffsetRangeTracker(start_position, stop_position)

    if isinstance(start_position, str):
      return LexicographicKeyRangeTracker(start_position, stop_position)

    raise NotImplementedError(
        f"RangeTracker for {type(start_position)} not implemented!")

  def read(self, range_tracker):
    """Returns an iterator that reads data from the source.

    The returned set of data must respect the boundaries defined by the given
    ``RangeTracker`` object. For example:

    * The returned set of data must be for the range
      ``[range_tracker.start_position, range_tracker.stop_position)``. Note
      that a source may decide to return records that start after
      ``range_tracker.stop_position``. See documentation in class
      ``RangeTracker`` for more details.
      Also, note that the framework might
      invoke ``range_tracker.try_split()`` to perform dynamic split
      operations. ``range_tracker.stop_position`` may be updated
      dynamically due to successful dynamic split operations.
    * Method ``range_tracker.try_split()`` must be invoked for every record
      that starts at a split point.
    * Method ``range_tracker.record_current_position()`` may be invoked for
      records that do not start at split points.

    Args:
      range_tracker: a ``RangeTracker`` whose boundaries must be respected
        when reading data from the source. A runner that reads this
        source must pass a ``RangeTracker`` object that is not
        ``None``.
    Returns:
      an iterator of data read by the source.
    """
    with MongoClient(self.uri, **self.spec) as client:
      all_filters = self._merge_id_filter(
          range_tracker.start_position(), range_tracker.stop_position())
      docs_cursor = (
          client[self.db][self.coll].find(
              filter=all_filters,
              projection=self.projection).sort([("_id", ASCENDING)]))
      for doc in docs_cursor:
        if not range_tracker.try_claim(doc["_id"]):
          return
        yield doc

  def display_data(self):
    """Returns the display data associated with a pipeline component."""
    res = super().display_data()
    res["database"] = self.db
    res["collection"] = self.coll
    res["filter"] = json.dumps(self.filter, default=json_util.default)
    res["projection"] = str(self.projection)
    res["bucket_auto"] = self.bucket_auto
    return res

  @staticmethod
  def _range_is_not_splittable(
      start_pos: Union[int, str, ObjectId],
      end_pos: Union[int, str, ObjectId],
  ):
    """Returns `True` if splitting the range doesn't make sense
    (a single document is not splittable), `False` otherwise.
    """
    return ((
        isinstance(start_pos, ObjectId) and
        start_pos >= _ObjectIdHelper.increment_id(end_pos, -1)) or
            (isinstance(start_pos, int) and start_pos >= end_pos - 1) or
            (isinstance(start_pos, str) and start_pos >= end_pos))

  def _get_split_keys(
      self,
      desired_chunk_size_in_mb: int,
      start_pos: Union[int, str, ObjectId],
      end_pos: Union[int, str, ObjectId],
  ):
    """Calls the MongoDB `splitVector` command
    to get document ids at split positions.
    """
    # A single document is not splittable.
    if self._range_is_not_splittable(start_pos, end_pos):
      return []

    with MongoClient(self.uri, **self.spec) as client:
      name_space = "%s.%s" % (self.db, self.coll)
      return client[self.db].command(
          "splitVector",
          name_space,
          keyPattern={"_id": 1},  # Ascending index
          min={"_id": start_pos},
          max={"_id": end_pos},
          maxChunkSize=desired_chunk_size_in_mb,
      )["splitKeys"]

  def _get_auto_buckets(
      self,
      desired_chunk_size_in_mb: int,
      start_pos: Union[int, str, ObjectId],
      end_pos: Union[int, str, ObjectId],
      is_initial_split: bool,
  ) -> list:
    """Use the MongoDB `$bucketAuto` aggregation to split the collection into
    bundles instead of the `splitVector` command, which does not work with
    MongoDB Atlas.
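
    The aggregation issued below is roughly (a sketch of the pipeline built
    in the method body)::

      [{"$match": <id-range filter merged with the custom filter, if any>},
       {"$bucketAuto": {"groupBy": "$_id", "buckets": <bucket count>}}]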
    """
    # A single document is not splittable.
    if self._range_is_not_splittable(start_pos, end_pos):
      return []

    if is_initial_split and not self.filter:
      # Total collection size in MB.
      size_in_mb = self.estimate_size() / float(1 << 20)
    else:
      # Size of documents within the start/end id range, possibly filtered.
      documents_count = self._count_id_range(start_pos, end_pos)
      avg_document_size = self._estimate_average_document_size()
      size_in_mb = documents_count * avg_document_size / float(1 << 20)

    if size_in_mb == 0:
      # No documents: nothing to split (possibly a result of filtering).
      return []

    bucket_count = math.ceil(size_in_mb / desired_chunk_size_in_mb)
    with beam.io.mongodbio.MongoClient(self.uri, **self.spec) as client:
      pipeline = [
          {
              # Filter by positions and by the custom filter, if any.
              "$match": self._merge_id_filter(start_pos, end_pos)
          },
          {
              "$bucketAuto": {
                  "groupBy": "$_id", "buckets": bucket_count
              }
          },
      ]
      buckets = list(
          # Use the `allowDiskUse` option to avoid the 100 MB RAM aggregation
          # limit.
          client[self.db][self.coll].aggregate(pipeline, allowDiskUse=True))
      if buckets:
        buckets[-1]["_id"]["max"] = end_pos

    return buckets

  def _merge_id_filter(
      self,
      start_position: Union[int, str, bytes, ObjectId],
      stop_position: Union[int, str, bytes, ObjectId] = None,
  ) -> dict:
    """Merge the default filter (if any) with the refined _id field range
    of the range_tracker.
    $gte specifies the start position (inclusive)
    and $lt specifies the end position (exclusive); see more at
    https://docs.mongodb.com/manual/reference/operator/query/gte/ and
    https://docs.mongodb.com/manual/reference/operator/query/lt/
    """

    if stop_position is None:
      id_filter = {"_id": {"$gte": start_position}}
    else:
      id_filter = {"_id": {"$gte": start_position, "$lt": stop_position}}

    if self.filter:
      all_filters = {
          # see more at
          # https://docs.mongodb.com/manual/reference/operator/query/and/
          "$and": [self.filter.copy(), id_filter]
      }
    else:
      all_filters = id_filter

    return all_filters

  def _get_head_document_id(self, sort_order):
    with MongoClient(self.uri, **self.spec) as client:
      cursor = (
          client[self.db][self.coll].find(filter={}, projection=[]).sort([
              ("_id", sort_order)
          ]).limit(1))
      try:
        return cursor[0]["_id"]

      except IndexError:
        raise ValueError("Empty MongoDB collection")

  def _replace_none_positions(self, start_position, stop_position):

    if start_position is None:
      start_position = self._get_head_document_id(ASCENDING)
    if stop_position is None:
      last_doc_id = self._get_head_document_id(DESCENDING)
      # Increment the last doc id binary value by 1 to make sure the last
      # document is not excluded.
      if isinstance(last_doc_id, ObjectId):
        stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1)
      elif isinstance(last_doc_id, int):
        stop_position = last_doc_id + 1
      elif isinstance(last_doc_id, str):
        stop_position = last_doc_id + '\x00'

    return start_position, stop_position

  def _count_id_range(self, start_position, stop_position):
    """Number of documents between start_position (inclusive)
    and stop_position (exclusive), respecting the custom filter if any.
    """
    with MongoClient(self.uri, **self.spec) as client:
      return client[self.db][self.coll].count_documents(
          filter=self._merge_id_filter(start_position, stop_position))


class _ObjectIdHelper:
  """A utility class to manipulate bson ObjectIds."""
  @classmethod
  def id_to_int(cls, _id: Union[int, ObjectId]) -> int:
    """
    Args:
      _id: ObjectId required for each MongoDB document _id field.

    Returns: The integer value converted from the ObjectId's 12-byte binary
      value.
    """
    if isinstance(_id, int):
      return _id

    # Converts the object id binary to an integer.
    # The id object is of bytes type with a size of 12.
    ints = struct.unpack(">III", _id.binary)
    return (ints[0] << 64) + (ints[1] << 32) + ints[2]

  @classmethod
  def int_to_id(cls, number):
    """
    Args:
      number(int): The integer value to be converted to an ObjectId.

    Returns: The ObjectId that has the 12-byte binary converted from the
      integer value.
    """
    # Converts an integer value to an object id. The int value should be less
    # than (2 ^ 96) so it can be converted to the 12 bytes required by an
    # object id.
    if number < 0 or number >= (1 << 96):
      raise ValueError("number value must be within [0, %s)" % (1 << 96))
    ints = [
        (number & 0xFFFFFFFF0000000000000000) >> 64,
        (number & 0x00000000FFFFFFFF00000000) >> 32,
        number & 0x0000000000000000FFFFFFFF,
    ]

    number_bytes = struct.pack(">III", *ints)
    return ObjectId(number_bytes)

  @classmethod
  def increment_id(
      cls,
      _id: ObjectId,
      inc: int,
  ) -> ObjectId:
    """
    Increment the ObjectId's binary value by `inc` and return a new ObjectId.

    Args:
      _id: The `_id` to change.
      inc(int): The incremental int value to be added to `_id`.

    Returns:
      `_id` incremented by the `inc` value
    """
    id_number = _ObjectIdHelper.id_to_int(_id)
    new_number = id_number + inc
    if new_number < 0 or new_number >= (1 << 96):
      raise ValueError(
          "invalid increment, inc value must be within ["
          "%s, %s)" % (0 - id_number, (1 << 96) - id_number))
    return _ObjectIdHelper.int_to_id(new_number)


class WriteToMongoDB(PTransform):
  """WriteToMongoDB is a ``PTransform`` that writes a ``PCollection`` of
  MongoDB documents to the configured MongoDB server.

  In order to make the document writes idempotent so that the bundles are
  retry-able without creating duplicates, the PTransform adds two
  transformations before the final write stage:
  a ``GenerateId`` transform and a ``Reshuffle`` transform::

                 -----------------------------------------------
    Pipeline --> |GenerateId --> Reshuffle --> WriteToMongoSink|
                 -----------------------------------------------
                                (WriteToMongoDB)

  The ``GenerateId`` transform adds a random and unique ``_id`` field to the
  documents if they don't already have one; it uses the same format as the
  MongoDB default. The ``Reshuffle`` transform makes sure that no fusion
  happens between ``GenerateId`` and the final write stage transform, so that
  the set of documents and their unique IDs are not regenerated if the final
  write step is retried due to a failure. This prevents duplicate writes of
  the same document with different unique IDs.
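
  Example usage (the URI, database, collection and batch size below are
  placeholders, matching the module-level example)::

    pipeline | WriteToMongoDB(uri='mongodb://localhost:27017',
                              db='testdb',
                              coll='output',
                              batch_size=10)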

  """
  def __init__(
      self,
      uri="mongodb://localhost:27017",
      db=None,
      coll=None,
      batch_size=100,
      extra_client_params=None,
  ):
    """

    Args:
      uri (str): The MongoDB connection string following the URI format.
      db (str): The MongoDB database name.
      coll (str): The MongoDB collection name.
      batch_size (int): Number of documents per bulk_write to MongoDB,
        defaults to 100.
      extra_client_params (dict): Optional `MongoClient
        <https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_
        parameters as keyword arguments.

    Returns:
      :class:`~apache_beam.transforms.ptransform.PTransform`

    """
    if extra_client_params is None:
      extra_client_params = {}
    if not isinstance(db, str):
      raise ValueError("WriteToMongoDB db param must be specified as a string")
    if not isinstance(coll, str):
      raise ValueError(
          "WriteToMongoDB coll param must be specified as a string")
    self._uri = uri
    self._db = db
    self._coll = coll
    self._batch_size = batch_size
    self._spec = extra_client_params

  def expand(self, pcoll):
    return (
        pcoll
        | beam.ParDo(_GenerateObjectIdFn())
        | Reshuffle()
        | beam.ParDo(
            _WriteMongoFn(
                self._uri, self._db, self._coll, self._batch_size,
                self._spec)))


class _GenerateObjectIdFn(DoFn):
  def process(self, element, *args, **kwargs):
    # If the _id field already exists, we keep it as it is; otherwise the
    # ptransform generates a new _id field to achieve an idempotent write to
    # MongoDB.
    if "_id" not in element:
      # objectid.ObjectId() generates a unique identifier that follows the
      # MongoDB default format; if _id is not present in a document, the
      # MongoDB server generates it with this same function upon write.
      # However, the uniqueness of the generated ids may not be guaranteed if
      # the workload is distributed across too many processes. See more on
      # the ObjectId format at
      # https://docs.mongodb.com/manual/reference/bson-types/#objectid.
      element["_id"] = objectid.ObjectId()

    yield element


class _WriteMongoFn(DoFn):
  def __init__(
      self, uri=None, db=None, coll=None, batch_size=100, extra_params=None):
    if extra_params is None:
      extra_params = {}
    self.uri = uri
    self.db = db
    self.coll = coll
    self.spec = extra_params
    self.batch_size = batch_size
    self.batch = []

  def finish_bundle(self):
    self._flush()

  def process(self, element, *args, **kwargs):
    self.batch.append(element)
    if len(self.batch) >= self.batch_size:
      self._flush()

  def _flush(self):
    if len(self.batch) == 0:
      return
    with _MongoSink(self.uri, self.db, self.coll, self.spec) as sink:
      sink.write(self.batch)
      self.batch = []

  def display_data(self):
    res = super().display_data()
    res["database"] = self.db
    res["collection"] = self.coll
    res["batch_size"] = self.batch_size
    return res


class _MongoSink:
  def __init__(self, uri=None, db=None, coll=None, extra_params=None):
    if extra_params is None:
      extra_params = {}
    self.uri = uri
    self.db = db
    self.coll = coll
    self.spec = extra_params
    self.client = None

  def write(self, documents):
    if self.client is None:
      self.client = MongoClient(host=self.uri, **self.spec)
    requests = []
    for doc in documents:
      # Match a document based on the _id field; if not found in the current
      # collection, insert a new one, otherwise overwrite it.
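      # Note: documents normally arrive here with an _id already set, because
      # _GenerateObjectIdFn runs upstream of the write stage; the
      # doc.get("_id", None) below is only a defensive fallback.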
      requests.append(
          ReplaceOne(
              filter={"_id": doc.get("_id", None)},
              replacement=doc,
              upsert=True))
    resp = self.client[self.db][self.coll].bulk_write(requests)
    _LOGGER.debug(
        "BulkWrite to MongoDB resulted in nModified:%d, nUpserted:%d, "
        "nMatched:%d, Errors:%s" % (
            resp.modified_count,
            resp.upserted_count,
            resp.matched_count,
            resp.bulk_api_result.get("writeErrors"),
        ))

  def __enter__(self):
    if self.client is None:
      self.client = MongoClient(host=self.uri, **self.spec)
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    if self.client is not None:
      self.client.close()