github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/dataframe/taxiride.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Pipelines that use the DataFrame API to process NYC taxiride CSV data."""
    19  
    20  # pytype: skip-file
    21  
    22  from __future__ import absolute_import
    23  
    24  import argparse
    25  import logging
    26  
    27  import apache_beam as beam
    28  from apache_beam.dataframe.io import read_csv
    29  from apache_beam.options.pipeline_options import PipelineOptions
    30  
    31  ZONE_LOOKUP_PATH = (
    32      "gs://apache-beam-samples/nyc_taxi/misc/taxi+_zone_lookup.csv")
    33  
    34  
    35  def run_aggregation_pipeline(pipeline, input_path, output_path):
    36    # The pipeline will be run on exiting the with block.
    37    # [START DataFrame_taxiride_aggregation]
    38    with pipeline as p:
    39      rides = p | read_csv(input_path)
    40  
    41      # Count the number of passengers dropped off per LocationID
    42      agg = rides.groupby('DOLocationID').passenger_count.sum()
    43      agg.to_csv(output_path)
    44      # [END DataFrame_taxiride_aggregation]
    45  
    46  
    47  def run_enrich_pipeline(
    48      pipeline, input_path, output_path, zone_lookup_path=ZONE_LOOKUP_PATH):
    49    """Enrich taxi ride data with zone lookup table and perform a grouped
    50    aggregation."""
    51    # The pipeline will be run on exiting the with block.
    52    # [START DataFrame_taxiride_enrich]
    53    with pipeline as p:
    54      rides = p | "Read taxi rides" >> read_csv(input_path)
    55      zones = p | "Read zone lookup" >> read_csv(zone_lookup_path)
    56  
    57      # Enrich taxi ride data with boroughs from zone lookup table
    58      # Joins on zones.LocationID and rides.DOLocationID, by first making the
    59      # former the index for zones.
    60      rides = rides.merge(
    61          zones.set_index('LocationID').Borough,
    62          right_index=True,
    63          left_on='DOLocationID',
    64          how='left')
    65  
    66      # Sum passengers dropped off per Borough
    67      agg = rides.groupby('Borough').passenger_count.sum()
    68      agg.to_csv(output_path)
    69      # [END DataFrame_taxiride_enrich]
    70  
    71      # A more intuitive alternative to the above merge call, but this option
    72      # doesn't preserve index, thus requires non-parallel execution.
    73      #rides = rides.merge(zones[['LocationID','Borough']],
    74      #                    how="left",
    75      #                    left_on='DOLocationID',
    76      #                    right_on='LocationID')
    77  
    78  
    79  def run(argv=None):
    80    """Main entry point."""
    81    parser = argparse.ArgumentParser(
    82        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    83    parser.add_argument(
    84        '--input',
    85        dest='input',
    86        default='gs://apache-beam-samples/nyc_taxi/misc/sample.csv',
    87        help='Input file to process.')
    88    parser.add_argument(
    89        '--output',
    90        dest='output',
    91        required=True,
    92        help='Output file to write results to.')
    93    parser.add_argument(
    94        '--zone_lookup',
    95        dest='zone_lookup_path',
    96        default=ZONE_LOOKUP_PATH,
    97        help='Location for taxi zone lookup CSV.')
    98    parser.add_argument(
    99        '--pipeline',
   100        dest='pipeline',
   101        default='location_id_agg',
   102        help=(
   103            "Choice of pipeline to run. Must be one of "
   104            "(location_id_agg, borough_enrich)."))
   105  
   106    known_args, pipeline_args = parser.parse_known_args(argv)
   107  
   108    pipeline = beam.Pipeline(options=PipelineOptions(pipeline_args))
   109  
   110    if known_args.pipeline == 'location_id_agg':
   111      run_aggregation_pipeline(pipeline, known_args.input, known_args.output)
   112    elif known_args.pipeline == 'borough_enrich':
   113      run_enrich_pipeline(
   114          pipeline,
   115          known_args.input,
   116          known_args.output,
   117          known_args.zone_lookup_path)
   118    else:
   119      raise ValueError(
   120          f"Unrecognized value for --pipeline: {known_args.pipeline!r}. "
   121          "Must be one of ('location_id_agg', 'borough_enrich')")
   122  
   123  
   124  if __name__ == '__main__':
   125    logging.getLogger().setLevel(logging.INFO)
   126    run()