github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/dataframe/taxiride_test.py (about) 1 # -*- coding: utf-8 -*- 2 # 3 # Licensed to the Apache Software Foundation (ASF) under one or more 4 # contributor license agreements. See the NOTICE file distributed with 5 # this work for additional information regarding copyright ownership. 6 # The ASF licenses this file to You under the Apache License, Version 2.0 7 # (the "License"); you may not use this file except in compliance with 8 # the License. You may obtain a copy of the License at 9 # 10 # http://www.apache.org/licenses/LICENSE-2.0 11 # 12 # Unless required by applicable law or agreed to in writing, software 13 # distributed under the License is distributed on an "AS IS" BASIS, 14 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 # See the License for the specific language governing permissions and 16 # limitations under the License. 17 # 18 19 """Unit tests for the taxiride example pipelines.""" 20 21 # pytype: skip-file 22 23 from __future__ import absolute_import 24 25 import glob 26 import logging 27 import os 28 import re 29 import tempfile 30 import unittest 31 32 import pandas as pd 33 34 import apache_beam as beam 35 from apache_beam.examples.dataframe import taxiride 36 from apache_beam.testing.util import open_shards 37 38 39 class TaxiRideExampleTest(unittest.TestCase): 40 41 # First 10 lines from gs://apache-beam-samples/nyc_taxi/misc/sample.csv 42 # pylint: disable=line-too-long 43 SAMPLE_RIDES = """VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,PULocationID,DOLocationID,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge 44 1,2019-01-01 00:46:40,2019-01-01 00:53:20,1,1.50,1,N,151,239,1,7,0.5,0.5,1.65,0,0.3,9.95, 45 1,2019-01-01 00:59:47,2019-01-01 01:18:59,1,2.60,1,N,239,246,1,14,0.5,0.5,1,0,0.3,16.3, 46 2,2018-12-21 13:48:30,2018-12-21 13:52:40,3,.00,1,N,236,236,1,4.5,0.5,0.5,0,0,0.3,5.8, 47 2,2018-11-28 15:52:25,2018-11-28 15:55:45,5,.00,1,N,193,193,2,3.5,0.5,0.5,0,0,0.3,7.55, 48 2,2018-11-28 15:56:57,2018-11-28 15:58:33,5,.00,2,N,193,193,2,52,0,0.5,0,0,0.3,55.55, 49 2,2018-11-28 16:25:49,2018-11-28 16:28:26,5,.00,1,N,193,193,2,3.5,0.5,0.5,0,5.76,0.3,13.31, 50 2,2018-11-28 16:29:37,2018-11-28 16:33:43,5,.00,2,N,193,193,2,52,0,0.5,0,0,0.3,55.55, 51 1,2019-01-01 00:21:28,2019-01-01 00:28:37,1,1.30,1,N,163,229,1,6.5,0.5,0.5,1.25,0,0.3,9.05, 52 1,2019-01-01 00:32:01,2019-01-01 00:45:39,1,3.70,1,N,229,7,1,13.5,0.5,0.5,3.7,0,0.3,18.5 53 """ 54 # pylint: enable=line-too-long 55 56 SAMPLE_ZONE_LOOKUP = """"LocationID","Borough","Zone","service_zone" 57 7,"Queens","Astoria","Boro Zone" 58 193,"Queens","Queensbridge/Ravenswood","Boro Zone" 59 229,"Manhattan","Sutton Place/Turtle Bay North","Yellow Zone" 60 236,"Manhattan","Upper East Side North","Yellow Zone" 61 239,"Manhattan","Upper West Side South","Yellow Zone" 62 246,"Manhattan","West Chelsea/Hudson Yards","Yellow Zone" 63 """ 64 65 def setUp(self): 66 self.tmpdir = tempfile.TemporaryDirectory() 67 self.input_path = os.path.join(self.tmpdir.name, 'rides*.csv') 68 self.lookup_path = os.path.join(self.tmpdir.name, 'lookup.csv') 69 self.output_path = os.path.join(self.tmpdir.name, 'output.csv') 70 71 # Duplicate sample data in 100 different files to replicate multi-file read 72 for i in range(100): 73 with open(os.path.join(self.tmpdir.name, f'rides{i}.csv'), 'w') as fp: 74 fp.write(self.SAMPLE_RIDES) 75 76 with open(self.lookup_path, 'w') as fp: 77 fp.write(self.SAMPLE_ZONE_LOOKUP) 78 79 def tearDown(self): 80 self.tmpdir.cleanup() 81 82 def test_aggregation(self): 83 # Compute expected result 84 rides = pd.concat(pd.read_csv(path) for path in glob.glob(self.input_path)) 85 expected_counts = rides.groupby('DOLocationID').passenger_count.sum() 86 87 taxiride.run_aggregation_pipeline( 88 beam.Pipeline(), self.input_path, self.output_path) 89 90 # Parse result file and compare. 91 # TODO(https://github.com/apache/beam/issues/20926): taxiride examples 92 # should produce int sums, not floats 93 results = [] 94 with open_shards(f'{self.output_path}-*') as result_file: 95 for line in result_file: 96 match = re.search(r'(\S+),([0-9\.]+)', line) 97 if match is not None: 98 results.append((int(match.group(1)), int(float(match.group(2))))) 99 elif line.strip(): 100 self.assertEqual(line.strip(), 'DOLocationID,passenger_count') 101 self.assertEqual(sorted(results), sorted(expected_counts.items())) 102 103 def test_enrich(self): 104 # Compute expected result 105 rides = pd.concat(pd.read_csv(path) for path in glob.glob(self.input_path)) 106 zones = pd.read_csv(self.lookup_path) 107 rides = rides.merge( 108 zones.set_index('LocationID').Borough, 109 right_index=True, 110 left_on='DOLocationID', 111 how='left') 112 expected_counts = rides.groupby('Borough').passenger_count.sum() 113 114 taxiride.run_enrich_pipeline( 115 beam.Pipeline(), self.input_path, self.output_path, self.lookup_path) 116 117 # Parse result file and compare. 118 # TODO(BEAM-XXXX): taxiride examples should produce int sums, not floats 119 results = [] 120 with open_shards(f'{self.output_path}-*') as result_file: 121 for line in result_file: 122 match = re.search(r'(\S+),([0-9\.]+)', line) 123 if match is not None: 124 results.append((match.group(1), int(float(match.group(2))))) 125 elif line.strip(): 126 self.assertEqual(line.strip(), 'Borough,passenger_count') 127 self.assertEqual(sorted(results), sorted(expected_counts.items())) 128 129 130 if __name__ == '__main__': 131 logging.getLogger().setLevel(logging.INFO) 132 unittest.main()