github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/dataframe/flight_delays.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """A pipeline using dataframes to compute typical flight delay times."""
    19  
    20  # pytype: skip-file
    21  
    22  from __future__ import absolute_import
    23  
    24  import argparse
    25  import logging
    26  
    27  import apache_beam as beam
    28  from apache_beam.dataframe.convert import to_dataframe
    29  from apache_beam.options.pipeline_options import PipelineOptions
    30  
    31  
    32  def get_mean_delay_at_top_airports(airline_df):
    33    arr = airline_df.rename(columns={
    34        'arrival_airport': 'airport'
    35    }).airport.value_counts()
    36    dep = airline_df.rename(columns={
    37        'departure_airport': 'airport'
    38    }).airport.value_counts()
    39    total = arr + dep
    40    # Note we keep all to include duplicates - this ensures the result is
    41    # deterministic.
    42    # NaNs can be included in the output in pandas 1.4.0 and above, so we
    43    # explicitly drop them.
    44    top_airports = total.nlargest(10, keep='all').dropna()
    45    at_top_airports = airline_df['arrival_airport'].isin(
    46        top_airports.index.values)
    47    return airline_df[at_top_airports].mean()
    48  
    49  
    50  def input_date(date):
    51    import datetime
    52    parsed = datetime.datetime.strptime(date, '%Y-%m-%d')
    53    if (parsed > datetime.datetime(2012, 12, 31) or
    54        parsed < datetime.datetime(2002, 1, 1)):
    55      raise ValueError("There's only data from 2002-01-01 to 2012-12-31")
    56    return date
    57  
    58  
    59  def run_flight_delay_pipeline(
    60      pipeline, start_date=None, end_date=None, output=None):
    61    query = f"""
    62    SELECT
    63      FlightDate AS date,
    64      IATA_CODE_Reporting_Airline AS airline,
    65      Origin AS departure_airport,
    66      Dest AS arrival_airport,
    67      DepDelay AS departure_delay,
    68      ArrDelay AS arrival_delay
    69    FROM `apache-beam-testing.airline_ontime_data.flights`
    70    WHERE
    71      FlightDate >= '{start_date}' AND FlightDate <= '{end_date}' AND
    72      DepDelay IS NOT NULL AND ArrDelay IS NOT NULL
    73    """
    74  
    75    # Import this here to avoid pickling the main session.
    76    import time
    77    from apache_beam import window
    78  
    79    def to_unixtime(s):
    80      return time.mktime(s.timetuple())
    81  
    82    # The pipeline will be run on exiting the with block.
    83    with pipeline as p:
    84      tbl = (
    85          p
    86          | 'read table' >> beam.io.ReadFromBigQuery(
    87              query=query, use_standard_sql=True)
    88          | 'assign timestamp' >>
    89          beam.Map(lambda x: window.TimestampedValue(x, to_unixtime(x['date'])))
    90          # Use beam.Select to make sure data has a schema
    91          # The casts in lambdas ensure data types are properly inferred
    92          | 'set schema' >> beam.Select(
    93              date=lambda x: str(x['date']),
    94              airline=lambda x: str(x['airline']),
    95              departure_airport=lambda x: str(x['departure_airport']),
    96              arrival_airport=lambda x: str(x['arrival_airport']),
    97              departure_delay=lambda x: float(x['departure_delay']),
    98              arrival_delay=lambda x: float(x['arrival_delay'])))
    99  
   100      daily = tbl | 'daily windows' >> beam.WindowInto(
   101          beam.window.FixedWindows(60 * 60 * 24))
   102  
   103      # group the flights data by carrier
   104      df = to_dataframe(daily)
   105      result = df.groupby('airline').apply(get_mean_delay_at_top_airports)
   106      result.to_csv(output)
   107  
   108  
   109  def run(argv=None):
   110    """Main entry point; defines and runs the flight delay pipeline."""
   111    parser = argparse.ArgumentParser()
   112    parser.add_argument(
   113        '--start_date',
   114        dest='start_date',
   115        type=input_date,
   116        default='2012-12-22',
   117        help='YYYY-MM-DD lower bound (inclusive) for input dataset.')
   118    parser.add_argument(
   119        '--end_date',
   120        dest='end_date',
   121        type=input_date,
   122        default='2012-12-26',
   123        help='YYYY-MM-DD upper bound (inclusive) for input dataset.')
   124    parser.add_argument(
   125        '--output',
   126        dest='output',
   127        required=True,
   128        help='Location to write the output.')
   129    known_args, pipeline_args = parser.parse_known_args(argv)
   130  
   131    run_flight_delay_pipeline(
   132        beam.Pipeline(options=PipelineOptions(pipeline_args)),
   133        start_date=known_args.start_date,
   134        end_date=known_args.end_date,
   135        output=known_args.output)
   136  
   137  
   138  if __name__ == '__main__':
   139    logging.getLogger().setLevel(logging.INFO)
   140    run()