github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/sklearn_inference_it_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""End-to-end integration tests for the sklearn inference examples."""

import logging
import re
import unittest
import uuid

import pytest

from apache_beam.examples.inference import sklearn_japanese_housing_regression
from apache_beam.examples.inference import sklearn_mnist_classification
from apache_beam.io.filesystems import FileSystems
from apache_beam.testing.test_pipeline import TestPipeline
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports, unused-import
# These tests read from and write to GCS paths, so skip the whole module when
# the GCP extras (and therefore the GCS filesystem) are not installed.
try:
  from apache_beam.io.gcp.gcsfilesystem import GCSFileSystem
except ImportError:
  raise unittest.SkipTest('GCP dependencies are not installed')


def process_outputs(filepath):
  """Reads a results file and returns decoded, newline-stripped lines."""
  with FileSystems().open(filepath) as f:
    lines = f.readlines()
  return [line.decode('utf-8').strip('\n') for line in lines]


def file_lines_sorted(filepath):
  """Same as process_outputs, but returns the lines in sorted order."""
  return sorted(process_outputs(filepath))


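# Illustrative sketch only: each line the helpers above return for the MNIST
# test is assumed to look like '<true_label>,<predicted_label>', which is what
# the comparison logic below parses with str.split(','). For example:
#
#   process_outputs('gs://apache-beam-ml/testing/expected_outputs/'
#                   'test_sklearn_mnist_classification_actuals.txt')
#   # -> ['6,6', '9,9', ...]  # values here are made up for illustration
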
@pytest.mark.uses_sklearn
@pytest.mark.it_postcommit
class SklearnInference(unittest.TestCase):
  def test_sklearn_mnist_classification(self):
    test_pipeline = TestPipeline(is_integration_test=True)
    input_file = 'gs://apache-beam-ml/testing/inputs/it_mnist_data.csv'
    output_file_dir = 'gs://temp-storage-for-end-to-end-tests'
    output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
    model_path = 'gs://apache-beam-ml/models/mnist_model_svm.pickle'
    extra_opts = {
        'input': input_file,
        'output': output_file,
        'model_path': model_path,
    }
    sklearn_mnist_classification.run(
        test_pipeline.get_full_options_as_args(**extra_opts),
        save_main_session=False)
    self.assertTrue(FileSystems().exists(output_file))
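
    # For reference: get_full_options_as_args turns extra_opts into
    # '--input=...', '--output=...' and '--model_path=...' flags, so this is
    # roughly what running the example module directly would look like (the
    # runner and output bucket below are placeholders, not values this test
    # uses):
    #
    #   python -m apache_beam.examples.inference.sklearn_mnist_classification \
    #       --input gs://apache-beam-ml/testing/inputs/it_mnist_data.csv \
    #       --output gs://<bucket>/<run-id>/result.txt \
    #       --model_path gs://apache-beam-ml/models/mnist_model_svm.pickle \
    #       --runner=DirectRunner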

    expected_output_filepath = 'gs://apache-beam-ml/testing/expected_outputs/test_sklearn_mnist_classification_actuals.txt'  # pylint: disable=line-too-long
    expected_outputs = process_outputs(expected_output_filepath)

    predicted_outputs = process_outputs(output_file)
    self.assertEqual(len(expected_outputs), len(predicted_outputs))

    # Index the actual predictions by true label, then check every expected
    # prediction against the actual prediction for the same label.
    predictions_dict = {}
    for line in predicted_outputs:
      true_label, prediction = line.split(',')
      predictions_dict[true_label] = prediction

    for line in expected_outputs:
      true_label, expected_prediction = line.split(',')
      self.assertEqual(predictions_dict[true_label], expected_prediction)

  def test_sklearn_regression(self):
    test_pipeline = TestPipeline(is_integration_test=True)
    input_file = 'gs://apache-beam-ml/testing/inputs/japanese_housing_test_data.csv'  # pylint: disable=line-too-long
    output_file_dir = 'gs://temp-storage-for-end-to-end-tests'
    output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
    model_path = 'gs://apache-beam-ml/models/japanese_housing/'
    extra_opts = {
        'input': input_file,
        'output': output_file,
        'model_path': model_path,
    }
    sklearn_japanese_housing_regression.run(
        test_pipeline.get_full_options_as_args(**extra_opts),
        save_main_session=False)
    self.assertTrue(FileSystems().exists(output_file))

    expected_output_filepath = 'gs://apache-beam-ml/testing/expected_outputs/japanese_housing_subset.txt'  # pylint: disable=line-too-long
    expected_outputs = file_lines_sorted(expected_output_filepath)
    actual_outputs = file_lines_sorted(output_file)
    self.assertEqual(len(expected_outputs), len(actual_outputs))

    for expected, actual in zip(expected_outputs, actual_outputs):
      expected_true, expected_predict = re.findall(r'\d+', expected)
      actual_true, actual_predict = re.findall(r'\d+', actual)
      # actual_true is the y value from the input csv file, so it should be an
      # exact match to expected_true.
      self.assertEqual(actual_true, expected_true)
      # Predictions might not be exactly equal due to differences between
      # environments, so only require that they agree within 10 percent.
      percent_diff = (
          abs(float(expected_predict) - float(actual_predict)) /
          float(expected_predict) * 100.0)
      self.assertLess(percent_diff, 10)
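
# How these postcommit tests are typically run (a sketch; the runner, project
# and bucket values are assumptions, not part of this file):
#
#   pytest -m "uses_sklearn and it_postcommit" \
#       apache_beam/ml/inference/sklearn_inference_it_test.py \
#       --test-pipeline-options="--runner=TestDataflowRunner \
#           --project=<your-project> --region=<region> \
#           --temp_location=gs://<your-bucket>/tmp"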


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.DEBUG)
  unittest.main()