github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/dataframe/flight_delays_it_test.py (about)

     1  # -*- coding: utf-8 -*-
     2  #
     3  # Licensed to the Apache Software Foundation (ASF) under one or more
     4  # contributor license agreements.  See the NOTICE file distributed with
     5  # this work for additional information regarding copyright ownership.
     6  # The ASF licenses this file to You under the Apache License, Version 2.0
     7  # (the "License"); you may not use this file except in compliance with
     8  # the License.  You may obtain a copy of the License at
     9  #
    10  #    http://www.apache.org/licenses/LICENSE-2.0
    11  #
    12  # Unless required by applicable law or agreed to in writing, software
    13  # distributed under the License is distributed on an "AS IS" BASIS,
    14  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  # See the License for the specific language governing permissions and
    16  # limitations under the License.
    17  #
    18  
    19  """Test for the flight delay example."""
    20  
    21  # pytype: skip-file
    22  
    23  from __future__ import absolute_import
    24  
    25  import logging
    26  import os
    27  import unittest
    28  import uuid
    29  
    30  import pandas as pd
    31  import pytest
    32  
    33  from apache_beam.examples.dataframe import flight_delays
    34  from apache_beam.io.filesystems import FileSystems
    35  from apache_beam.testing.test_pipeline import TestPipeline
    36  
    37  
    38  class FlightDelaysTest(unittest.TestCase):
    39    EXPECTED = {
    40        '2012-12-23': [
    41            ('AA', 20.082559339525282, 12.825593395252838),
    42            ('AS', 5.0456273764258555, 1.0722433460076046),
    43            ('B6', 20.646569646569645, 16.405405405405407),
    44            ('DL', 5.241148325358852, -3.2401913875598085),
    45            ('EV', 9.982053838484546, 4.40777666999003),
    46            ('F9', 23.67883211678832, 25.27007299270073),
    47            ('FL', 4.4602272727272725, -0.8352272727272727),
    48            ('HA', -1.0829015544041452, 0.010362694300518135),
    49            ('MQ', 8.912912912912914, 3.6936936936936937),
    50            ('OO', 30.526699029126213, 31.17961165048544),
    51            ('UA', 19.142555438225976, 11.07180570221753),
    52            ('US', 3.092541436464088, -2.350828729281768),
    53            ('VX', 62.755102040816325, 62.61224489795919),
    54            ('WN', 12.05824508320726, 6.713313161875946),
    55            ('YV', 16.155844155844157, 13.376623376623376),
    56        ],
    57        '2012-12-24': [
    58            ('AA', 7.049086757990867, -1.5970319634703196),
    59            ('AS', 0.5917602996254682, -2.2659176029962547),
    60            ('B6', 8.070993914807302, 2.73630831643002),
    61            ('DL', 3.700745473908413, -2.2396166134185305),
    62            ('EV', 7.322115384615385, 2.3653846153846154),
    63            ('F9', 13.786764705882351, 15.5),
    64            ('FL', 2.416909620991253, 2.224489795918368),
    65            ('HA', -2.6785714285714284, -2.4744897959183674),
    66            ('MQ', 15.818181818181818, 9.935828877005347),
    67            ('OO', 10.902374670184695, 10.08575197889182),
    68            ('UA', 10.935406698564593, -1.3337320574162679),
    69            ('US', 1.369281045751634, -1.4101307189542485),
    70            ('VX', 3.841666666666667, -2.4166666666666665),
    71            ('WN', 7.3715753424657535, 0.348458904109589),
    72            ('YV', 0.32, 0.78),
    73        ],
    74        '2012-12-25': [
    75            ('AA', 23.551581843191197, 35.62585969738652),
    76            ('AS', 3.4816326530612245, 0.27346938775510204),
    77            ('B6', 9.10590631364562, 3.989816700610998),
    78            ('DL', 2.2863795110593714, -3.668218859138533),
    79            ('EV', 17.35576923076923, 16.414835164835164),
    80            ('F9', 19.38, 21.786666666666665),
    81            ('FL', 1.3823529411764706, 0.9205882352941176),
    82            ('HA', -4.725806451612903, -3.9946236559139785),
    83            ('MQ', 32.527716186252775, 44.148558758314856),
    84            ('OO', 15.788595271210012, 16.617524339360223),
    85            ('UA', 16.663145539906104, 10.772300469483568),
    86            ('US', 2.7953216374269005, 0.2236842105263158),
    87            ('VX', 23.62878787878788, 23.636363636363637),
    88            ('WN', 14.423791821561338, 10.142193308550183),
    89            ('YV', 11.256302521008404, 11.659663865546218),
    90        ],
    91    }
    92  
    93    def setUp(self):
    94      self.test_pipeline = TestPipeline(is_integration_test=True)
    95      self.outdir = (
    96          self.test_pipeline.get_option('temp_location') + '/flight_delays_it-' +
    97          str(uuid.uuid4()))
    98      self.output_path = os.path.join(self.outdir, 'output.csv')
    99  
   100    def tearDown(self):
   101      FileSystems.delete([self.outdir + '/'])
   102  
   103    @pytest.mark.examples_postcommit
   104    @pytest.mark.it_postcommit
   105    def test_flight_delays(self):
   106      flight_delays.run_flight_delay_pipeline(
   107          self.test_pipeline,
   108          start_date='2012-12-23',
   109          end_date='2012-12-25',
   110          output=self.output_path)
   111  
   112      def read_csv(path):
   113        with FileSystems.open(path) as fp:
   114          return pd.read_csv(fp)
   115  
   116      # Parse result file and compare.
   117      for date, expectation in self.EXPECTED.items():
   118        result_df = pd.concat(
   119            read_csv(metadata.path) for metadata in FileSystems.match(
   120                [f'{self.output_path}-{date}*'])[0].metadata_list)
   121        result_df = result_df.sort_values('airline').reset_index(drop=True)
   122  
   123        expected_df = pd.DataFrame(
   124            expectation, columns=['airline', 'departure_delay', 'arrival_delay'])
   125        expected_df = expected_df.sort_values('airline').reset_index(drop=True)
   126  
   127        try:
   128          pd.testing.assert_frame_equal(result_df, expected_df)
   129        except AssertionError as e:
   130          raise AssertionError(
   131              f"date={date!r} result DataFrame:\n\n"
   132              f"{result_df}\n\n"
   133              "Differs from Expectation:\n\n"
   134              f"{expected_df}") from e
   135  
   136  
   137  if __name__ == '__main__':
   138    logging.getLogger().setLevel(logging.INFO)
   139    unittest.main()