github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/dataframe/taxiride_test.py (about)

     1  # -*- coding: utf-8 -*-
     2  #
     3  # Licensed to the Apache Software Foundation (ASF) under one or more
     4  # contributor license agreements.  See the NOTICE file distributed with
     5  # this work for additional information regarding copyright ownership.
     6  # The ASF licenses this file to You under the Apache License, Version 2.0
     7  # (the "License"); you may not use this file except in compliance with
     8  # the License.  You may obtain a copy of the License at
     9  #
    10  #    http://www.apache.org/licenses/LICENSE-2.0
    11  #
    12  # Unless required by applicable law or agreed to in writing, software
    13  # distributed under the License is distributed on an "AS IS" BASIS,
    14  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  # See the License for the specific language governing permissions and
    16  # limitations under the License.
    17  #
    18  
    19  """Unit tests for the taxiride example pipelines."""
    20  
    21  # pytype: skip-file
    22  
    23  from __future__ import absolute_import
    24  
    25  import glob
    26  import logging
    27  import os
    28  import re
    29  import tempfile
    30  import unittest
    31  
    32  import pandas as pd
    33  
    34  import apache_beam as beam
    35  from apache_beam.examples.dataframe import taxiride
    36  from apache_beam.testing.util import open_shards
    37  
    38  
    39  class TaxiRideExampleTest(unittest.TestCase):
    40  
    41    # First 10 lines from gs://apache-beam-samples/nyc_taxi/misc/sample.csv
    42    # pylint: disable=line-too-long
    43    SAMPLE_RIDES = """VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,PULocationID,DOLocationID,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge
    44    1,2019-01-01 00:46:40,2019-01-01 00:53:20,1,1.50,1,N,151,239,1,7,0.5,0.5,1.65,0,0.3,9.95,
    45    1,2019-01-01 00:59:47,2019-01-01 01:18:59,1,2.60,1,N,239,246,1,14,0.5,0.5,1,0,0.3,16.3,
    46    2,2018-12-21 13:48:30,2018-12-21 13:52:40,3,.00,1,N,236,236,1,4.5,0.5,0.5,0,0,0.3,5.8,
    47    2,2018-11-28 15:52:25,2018-11-28 15:55:45,5,.00,1,N,193,193,2,3.5,0.5,0.5,0,0,0.3,7.55,
    48    2,2018-11-28 15:56:57,2018-11-28 15:58:33,5,.00,2,N,193,193,2,52,0,0.5,0,0,0.3,55.55,
    49    2,2018-11-28 16:25:49,2018-11-28 16:28:26,5,.00,1,N,193,193,2,3.5,0.5,0.5,0,5.76,0.3,13.31,
    50    2,2018-11-28 16:29:37,2018-11-28 16:33:43,5,.00,2,N,193,193,2,52,0,0.5,0,0,0.3,55.55,
    51    1,2019-01-01 00:21:28,2019-01-01 00:28:37,1,1.30,1,N,163,229,1,6.5,0.5,0.5,1.25,0,0.3,9.05,
    52    1,2019-01-01 00:32:01,2019-01-01 00:45:39,1,3.70,1,N,229,7,1,13.5,0.5,0.5,3.7,0,0.3,18.5
    53    """
    54    # pylint: enable=line-too-long
    55  
    56    SAMPLE_ZONE_LOOKUP = """"LocationID","Borough","Zone","service_zone"
    57    7,"Queens","Astoria","Boro Zone"
    58    193,"Queens","Queensbridge/Ravenswood","Boro Zone"
    59    229,"Manhattan","Sutton Place/Turtle Bay North","Yellow Zone"
    60    236,"Manhattan","Upper East Side North","Yellow Zone"
    61    239,"Manhattan","Upper West Side South","Yellow Zone"
    62    246,"Manhattan","West Chelsea/Hudson Yards","Yellow Zone"
    63    """
    64  
    65    def setUp(self):
    66      self.tmpdir = tempfile.TemporaryDirectory()
    67      self.input_path = os.path.join(self.tmpdir.name, 'rides*.csv')
    68      self.lookup_path = os.path.join(self.tmpdir.name, 'lookup.csv')
    69      self.output_path = os.path.join(self.tmpdir.name, 'output.csv')
    70  
    71      # Duplicate sample data in 100 different files to replicate multi-file read
    72      for i in range(100):
    73        with open(os.path.join(self.tmpdir.name, f'rides{i}.csv'), 'w') as fp:
    74          fp.write(self.SAMPLE_RIDES)
    75  
    76      with open(self.lookup_path, 'w') as fp:
    77        fp.write(self.SAMPLE_ZONE_LOOKUP)
    78  
    79    def tearDown(self):
    80      self.tmpdir.cleanup()
    81  
    82    def test_aggregation(self):
    83      # Compute expected result
    84      rides = pd.concat(pd.read_csv(path) for path in glob.glob(self.input_path))
    85      expected_counts = rides.groupby('DOLocationID').passenger_count.sum()
    86  
    87      taxiride.run_aggregation_pipeline(
    88          beam.Pipeline(), self.input_path, self.output_path)
    89  
    90      # Parse result file and compare.
    91      # TODO(https://github.com/apache/beam/issues/20926): taxiride examples
    92      # should produce int sums, not floats
    93      results = []
    94      with open_shards(f'{self.output_path}-*') as result_file:
    95        for line in result_file:
    96          match = re.search(r'(\S+),([0-9\.]+)', line)
    97          if match is not None:
    98            results.append((int(match.group(1)), int(float(match.group(2)))))
    99          elif line.strip():
   100            self.assertEqual(line.strip(), 'DOLocationID,passenger_count')
   101      self.assertEqual(sorted(results), sorted(expected_counts.items()))
   102  
   103    def test_enrich(self):
   104      # Compute expected result
   105      rides = pd.concat(pd.read_csv(path) for path in glob.glob(self.input_path))
   106      zones = pd.read_csv(self.lookup_path)
   107      rides = rides.merge(
   108          zones.set_index('LocationID').Borough,
   109          right_index=True,
   110          left_on='DOLocationID',
   111          how='left')
   112      expected_counts = rides.groupby('Borough').passenger_count.sum()
   113  
   114      taxiride.run_enrich_pipeline(
   115          beam.Pipeline(), self.input_path, self.output_path, self.lookup_path)
   116  
   117      # Parse result file and compare.
   118      # TODO(BEAM-XXXX): taxiride examples should produce int sums, not floats
   119      results = []
   120      with open_shards(f'{self.output_path}-*') as result_file:
   121        for line in result_file:
   122          match = re.search(r'(\S+),([0-9\.]+)', line)
   123          if match is not None:
   124            results.append((match.group(1), int(float(match.group(2)))))
   125          elif line.strip():
   126            self.assertEqual(line.strip(), 'Borough,passenger_count')
   127      self.assertEqual(sorted(results), sorted(expected_counts.items()))
   128  
   129  
   130  if __name__ == '__main__':
   131    logging.getLogger().setLevel(logging.INFO)
   132    unittest.main()