#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Pipelines that use the DataFrame API to process NYC taxiride CSV data."""

# pytype: skip-file

from __future__ import absolute_import

import argparse
import logging

import apache_beam as beam
from apache_beam.dataframe.io import read_csv
from apache_beam.options.pipeline_options import PipelineOptions

# Default location of the zone lookup CSV used by run_enrich_pipeline.
ZONE_LOOKUP_PATH = (
    "gs://apache-beam-samples/nyc_taxi/misc/taxi+_zone_lookup.csv")


def run_aggregation_pipeline(pipeline, input_path, output_path):
  """Sum the passenger count per drop-off LocationID and write it as CSV.

  Args:
    pipeline: The Beam pipeline to build on; it is used as a context manager.
    input_path: Path of the taxi ride CSV, passed to read_csv.
    output_path: Path passed to to_csv for the aggregated result.
  """
  # The pipeline will be run on exiting the with block.
  # [START DataFrame_taxiride_aggregation]
  with pipeline as p:
    rides = p | read_csv(input_path)

    # Count the number of passengers dropped off per LocationID
    agg = rides.groupby('DOLocationID').passenger_count.sum()
    agg.to_csv(output_path)
    # [END DataFrame_taxiride_aggregation]


def run_enrich_pipeline(
    pipeline, input_path, output_path, zone_lookup_path=ZONE_LOOKUP_PATH):
  """Enrich taxi ride data with zone lookup table and perform a grouped
  aggregation.

  Args:
    pipeline: The Beam pipeline to build on; it is used as a context manager.
    input_path: Path of the taxi ride CSV, passed to read_csv.
    output_path: Path passed to to_csv for the aggregated result.
    zone_lookup_path: Path of the zone lookup CSV; defaults to
      ZONE_LOOKUP_PATH.
  """
  # The pipeline will be run on exiting the with block.
  # [START DataFrame_taxiride_enrich]
  with pipeline as p:
    rides = p | "Read taxi rides" >> read_csv(input_path)
    zones = p | "Read zone lookup" >> read_csv(zone_lookup_path)

    # Enrich taxi ride data with boroughs from zone lookup table
    # Joins on zones.LocationID and rides.DOLocationID, by first making the
    # former the index for zones.
    rides = rides.merge(
        zones.set_index('LocationID').Borough,
        right_index=True,
        left_on='DOLocationID',
        how='left')

    # Sum passengers dropped off per Borough
    agg = rides.groupby('Borough').passenger_count.sum()
    agg.to_csv(output_path)
    # [END DataFrame_taxiride_enrich]

    # A more intuitive alternative to the above merge call, but this option
    # doesn't preserve index, thus requires non-parallel execution.
    #rides = rides.merge(zones[['LocationID','Borough']],
    #                    how="left",
    #                    left_on='DOLocationID',
    #                    right_on='LocationID')


def run(argv=None):
  """Main entry point.

  Parses the example's own flags, forwards every unrecognized argument to
  PipelineOptions, and runs the pipeline selected with --pipeline.

  Args:
    argv: Argument list handed to the parser; None lets argparse fall back to
      its default of reading the process arguments.

  Raises:
    ValueError: If --pipeline is neither 'location_id_agg' nor
      'borough_enrich'.
  """
  parser = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument(
      '--input',
      dest='input',
      default='gs://apache-beam-samples/nyc_taxi/misc/sample.csv',
      help='Input file to process.')
  parser.add_argument(
      '--output',
      dest='output',
      required=True,
      help='Output file to write results to.')
  parser.add_argument(
      '--zone_lookup',
      dest='zone_lookup_path',
      default=ZONE_LOOKUP_PATH,
      help='Location for taxi zone lookup CSV.')
  parser.add_argument(
      '--pipeline',
      dest='pipeline',
      default='location_id_agg',
      help=(
          "Choice of pipeline to run. Must be one of "
          "(location_id_agg, borough_enrich)."))

  # Anything this parser does not recognize is left for PipelineOptions.
  known_args, pipeline_args = parser.parse_known_args(argv)

  pipeline = beam.Pipeline(options=PipelineOptions(pipeline_args))

  if known_args.pipeline == 'location_id_agg':
    run_aggregation_pipeline(pipeline, known_args.input, known_args.output)
  elif known_args.pipeline == 'borough_enrich':
    run_enrich_pipeline(
        pipeline,
        known_args.input,
        known_args.output,
        known_args.zone_lookup_path)
  else:
    raise ValueError(
        f"Unrecognized value for --pipeline: {known_args.pipeline!r}. "
        "Must be one of ('location_id_agg', 'borough_enrich')")


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  run()