github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/mongodbio_it_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  import argparse
    21  import logging
    22  import time
    23  
    24  from pymongo import MongoClient
    25  
    26  import apache_beam as beam
    27  from apache_beam.options.pipeline_options import PipelineOptions
    28  from apache_beam.testing.test_pipeline import TestPipeline
    29  from apache_beam.testing.util import assert_that
    30  from apache_beam.testing.util import equal_to
    31  
# Module-level logger, named after this module via ``__name__``.
_LOGGER = logging.getLogger(__name__)
    33  
    34  
    35  class GenerateDocs(beam.DoFn):
    36    def process(self, num_docs, *args, **kwargs):
    37      for i in range(num_docs):
    38        yield {'number': i, 'number_mod_2': i % 2, 'number_mod_3': i % 3}
    39  
    40  
    41  def run(argv=None):
    42    default_db = 'beam_mongodbio_it_db'
    43    default_coll = 'integration_test_%d' % time.time()
    44    parser = argparse.ArgumentParser()
    45    parser.add_argument(
    46        '--mongo_uri',
    47        default='mongodb://localhost:27017',
    48        help='mongo uri string for connection')
    49    parser.add_argument(
    50        '--mongo_db', default=default_db, help='mongo uri string for connection')
    51    parser.add_argument(
    52        '--mongo_coll',
    53        default=default_coll,
    54        help='mongo uri string for connection')
    55    parser.add_argument(
    56        '--num_documents',
    57        default=100000,
    58        help='The expected number of documents to be generated '
    59        'for write or read',
    60        type=int)
    61    parser.add_argument(
    62        '--batch_size',
    63        default=10000,
    64        type=int,
    65        help=('batch size for writing to mongodb'))
    66    known_args, pipeline_args = parser.parse_known_args(argv)
    67  
    68    # Test Write to MongoDB
    69    with TestPipeline(options=PipelineOptions(pipeline_args)) as p:
    70      start_time = time.time()
    71      _LOGGER.info('Writing %d documents to mongodb', known_args.num_documents)
    72  
    73      _ = (
    74          p | beam.Create([known_args.num_documents])
    75          | 'Create documents' >> beam.ParDo(GenerateDocs())
    76          | 'WriteToMongoDB' >> beam.io.WriteToMongoDB(
    77              known_args.mongo_uri,
    78              known_args.mongo_db,
    79              known_args.mongo_coll,
    80              known_args.batch_size))
    81    elapsed = time.time() - start_time
    82    _LOGGER.info(
    83        'Writing %d documents to mongodb finished in %.3f seconds' %
    84        (known_args.num_documents, elapsed))
    85  
    86    # Test Read from MongoDB
    87    total_sum = sum(range(known_args.num_documents))
    88    mod_3_sum = sum(
    89        num for num in range(known_args.num_documents) if num % 3 == 0)
    90    mod_3_count = sum(
    91        1 for num in range(known_args.num_documents) if num % 3 == 0)
    92    # yapf: disable
    93    read_cases = [
    94        # (reader_params, expected)
    95        (
    96            {
    97                'projection': ['number']
    98            },
    99            {
   100                'number_sum': total_sum,
   101                'docs_count': known_args.num_documents
   102            }
   103        ),
   104        (
   105            {
   106                'filter': {'number_mod_3': 0},
   107                'projection': ['number']
   108            },
   109            {
   110                'number_sum': mod_3_sum,
   111                'docs_count': mod_3_count
   112            }
   113        ),
   114        (
   115            {
   116                'projection': ['number'],
   117                'bucket_auto': True
   118            },
   119            {
   120                'number_sum': total_sum,
   121                'docs_count': known_args.num_documents
   122            }
   123        ),
   124        (
   125            {
   126                'filter': {'number_mod_3': 0},
   127                'projection': ['number'],
   128                'bucket_auto': True
   129            },
   130            {
   131                'number_sum': mod_3_sum,
   132                'docs_count': mod_3_count
   133            }
   134        ),
   135    ]
   136    # yapf: enable
   137    for reader_params, expected in read_cases:
   138      with TestPipeline(options=PipelineOptions(pipeline_args)) as p:
   139        start_time = time.time()
   140        _LOGGER.info('=' * 80)
   141        _LOGGER.info(
   142            'Reading from mongodb %s:%s',
   143            known_args.mongo_db,
   144            known_args.mongo_coll)
   145        _LOGGER.info('reader params   : %s', reader_params)
   146        _LOGGER.info('expected results: %s', expected)
   147        docs = (
   148            p | 'ReadFromMongoDB' >> beam.io.ReadFromMongoDB(
   149                known_args.mongo_uri,
   150                known_args.mongo_db,
   151                known_args.mongo_coll,
   152                **reader_params)
   153            | 'Map' >> beam.Map(lambda doc: doc['number']))
   154        number_sum = (docs | 'Combine' >> beam.CombineGlobally(sum))
   155        docs_count = (docs | 'Count' >> beam.combiners.Count.Globally())
   156        r = ([number_sum, docs_count] | 'Flatten' >> beam.Flatten())
   157        assert_that(r, equal_to([expected['number_sum'], expected['docs_count']]))
   158  
   159      elapsed = time.time() - start_time
   160      _LOGGER.info(
   161          'Reading documents from mongodb finished in %.3f seconds', elapsed)
   162  
   163    # Clean-up
   164    with MongoClient(host=known_args.mongo_uri) as client:
   165      client.drop_database(known_args.mongo_db)
   166  
   167  
   168  if __name__ == "__main__":
   169    logging.getLogger().setLevel(logging.INFO)
   170    run()