github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/dataframe/taxiride_it_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """End-to-end tests for the taxiride examples."""
    19  
    20  # pytype: skip-file
    21  
    22  import logging
    23  import os
    24  import unittest
    25  import uuid
    26  
    27  import pandas as pd
    28  import pytest
    29  
    30  from apache_beam.examples.dataframe import taxiride
    31  from apache_beam.io.filesystems import FileSystems
    32  from apache_beam.options.pipeline_options import WorkerOptions
    33  from apache_beam.testing.test_pipeline import TestPipeline
    34  
    35  
    36  class TaxirideIT(unittest.TestCase):
    37    def setUp(self):
    38      self.test_pipeline = TestPipeline(is_integration_test=True)
    39      self.outdir = (
    40          self.test_pipeline.get_option('temp_location') + '/taxiride_it-' +
    41          str(uuid.uuid4()))
    42      self.output_path = os.path.join(self.outdir, 'output.csv')
    43  
    44    def tearDown(self):
    45      FileSystems.delete([self.outdir + '/'])
    46  
    47    @pytest.mark.it_postcommit
    48    def test_aggregation(self):
    49      taxiride.run_aggregation_pipeline(
    50          self.test_pipeline,
    51          'gs://apache-beam-samples/nyc_taxi/2018/*.csv',
    52          self.output_path)
    53  
    54      # Verify
    55      expected = pd.read_csv(
    56          os.path.join(
    57              os.path.dirname(__file__),
    58              'data',
    59              'taxiride_2018_aggregation_truth.csv'),
    60          comment='#')
    61      expected = expected.sort_values('DOLocationID').reset_index(drop=True)
    62  
    63      def read_csv(path):
    64        with FileSystems.open(path) as fp:
    65          return pd.read_csv(fp)
    66  
    67      result = pd.concat(
    68          read_csv(metadata.path) for metadata in FileSystems.match(
    69              [f'{self.output_path}*'])[0].metadata_list)
    70      result = result.sort_values('DOLocationID').reset_index(drop=True)
    71  
    72      pd.testing.assert_frame_equal(expected, result)
    73  
    74    @pytest.mark.it_postcommit
    75    def test_enrich(self):
    76      # Standard workers OOM with the enrich pipeline
    77      self.test_pipeline.get_pipeline_options().view_as(
    78          WorkerOptions).machine_type = 'e2-highmem-2'
    79  
    80      taxiride.run_enrich_pipeline(
    81          self.test_pipeline,
    82          'gs://apache-beam-samples/nyc_taxi/2018/*.csv',
    83          self.output_path)
    84  
    85      # Verify
    86      expected = pd.read_csv(
    87          os.path.join(
    88              os.path.dirname(__file__), 'data',
    89              'taxiride_2018_enrich_truth.csv'),
    90          comment='#')
    91      expected = expected.sort_values('Borough').reset_index(drop=True)
    92  
    93      def read_csv(path):
    94        with FileSystems.open(path) as fp:
    95          return pd.read_csv(fp)
    96  
    97      result = pd.concat(
    98          read_csv(metadata.path) for metadata in FileSystems.match(
    99              [f'{self.output_path}*'])[0].metadata_list)
   100      result = result.sort_values('Borough').reset_index(drop=True)
   101  
   102      pd.testing.assert_frame_equal(expected, result)
   103  
   104  
   105  if __name__ == '__main__':
   106    logging.getLogger().setLevel(logging.DEBUG)
   107    unittest.main()