github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/sklearn_inference_it_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""End-to-End test for Sklearn Inference"""

import logging
import re
import unittest
import uuid

import pytest

from apache_beam.examples.inference import sklearn_japanese_housing_regression
from apache_beam.examples.inference import sklearn_mnist_classification
from apache_beam.io.filesystems import FileSystems
from apache_beam.testing.test_pipeline import TestPipeline

# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports, unused-import
try:
  from apache_beam.io.gcp.gcsfilesystem import GCSFileSystem
except ImportError:
  raise unittest.SkipTest('GCP dependencies are not installed')


def process_outputs(filepath):
  with FileSystems().open(filepath) as f:
    lines = f.readlines()
    lines = [l.decode('utf-8').strip('\n') for l in lines]
    return lines


def file_lines_sorted(filepath):
  with FileSystems().open(filepath) as f:
    lines = f.readlines()
    lines = [l.decode('utf-8').strip('\n') for l in lines]
    return sorted(lines)


@pytest.mark.uses_sklearn
@pytest.mark.it_postcommit
class SklearnInference(unittest.TestCase):
  def test_sklearn_mnist_classification(self):
    test_pipeline = TestPipeline(is_integration_test=True)
    input_file = 'gs://apache-beam-ml/testing/inputs/it_mnist_data.csv'
    output_file_dir = 'gs://temp-storage-for-end-to-end-tests'
    output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
    model_path = 'gs://apache-beam-ml/models/mnist_model_svm.pickle'
    extra_opts = {
        'input': input_file,
        'output': output_file,
        'model_path': model_path,
    }
    sklearn_mnist_classification.run(
        test_pipeline.get_full_options_as_args(**extra_opts),
        save_main_session=False)
    self.assertEqual(FileSystems().exists(output_file), True)

    expected_output_filepath = 'gs://apache-beam-ml/testing/expected_outputs/test_sklearn_mnist_classification_actuals.txt'  # pylint: disable=line-too-long
    expected_outputs = process_outputs(expected_output_filepath)

    predicted_outputs = process_outputs(output_file)
    self.assertEqual(len(expected_outputs), len(predicted_outputs))

    predictions_dict = {}
    for i in range(len(predicted_outputs)):
      true_label, prediction = predicted_outputs[i].split(',')
      predictions_dict[true_label] = prediction

    for i in range(len(expected_outputs)):
      true_label, expected_prediction = expected_outputs[i].split(',')
      self.assertEqual(predictions_dict[true_label], expected_prediction)

  def test_sklearn_regression(self):
    test_pipeline = TestPipeline(is_integration_test=True)
    input_file = 'gs://apache-beam-ml/testing/inputs/japanese_housing_test_data.csv'  # pylint: disable=line-too-long
    output_file_dir = 'gs://temp-storage-for-end-to-end-tests'
    output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
    model_path = 'gs://apache-beam-ml/models/japanese_housing/'
    extra_opts = {
        'input': input_file,
        'output': output_file,
        'model_path': model_path,
    }
    sklearn_japanese_housing_regression.run(
        test_pipeline.get_full_options_as_args(**extra_opts),
        save_main_session=False)
    self.assertEqual(FileSystems().exists(output_file), True)

    expected_output_filepath = 'gs://apache-beam-ml/testing/expected_outputs/japanese_housing_subset.txt'  # pylint: disable=line-too-long
    expected_outputs = file_lines_sorted(expected_output_filepath)
    actual_outputs = file_lines_sorted(output_file)
    self.assertEqual(len(expected_outputs), len(actual_outputs))

    for expected, actual in zip(expected_outputs, actual_outputs):
      expected_true, expected_predict = re.findall(r'\d+', expected)
      actual_true, actual_predict = re.findall(r'\d+', actual)
      # actual_true is the y value from the input csv file.
      # Therefore it should be an exact match to expected_true.
      self.assertEqual(actual_true, expected_true)
      # Predictions might not be exactly equal due to differences between
      # environments. This code validates they are within 10 percent.
      percent_diff = abs(float(expected_predict) - float(actual_predict)
                         ) / float(expected_predict) * 100.0
      self.assertLess(percent_diff, 10)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.DEBUG)
  unittest.main()
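
For context, the example modules invoked above (sklearn_mnist_classification and sklearn_japanese_housing_regression) build RunInference pipelines around Beam's scikit-learn model handlers. The snippet below is a minimal, hypothetical sketch of that pattern, not the example code itself: it assumes a pickled scikit-learn model and numpy feature rows, and it omits the CSV parsing, keying, and output writing that the real examples perform.

# Minimal sketch (not the tested example modules) of a sklearn RunInference
# pipeline of the kind these integration tests exercise.
import numpy

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.sklearn_inference import ModelFileType
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

# Assumes a pickled scikit-learn model; the MNIST test above uses this URI.
model_handler = SklearnModelHandlerNumpy(
    model_uri='gs://apache-beam-ml/models/mnist_model_svm.pickle',
    model_file_type=ModelFileType.PICKLE)

with beam.Pipeline() as p:
  _ = (
      p
      # One 784-feature row per element, matching the MNIST model's input.
      | 'CreateExamples' >> beam.Create(
          [numpy.zeros(784, dtype=numpy.float64)])
      | 'RunInference' >> RunInference(model_handler)
      # Each PredictionResult carries the input example and its inference.
      | 'Format' >> beam.Map(lambda result: str(result.inference))
      | 'Print' >> beam.Map(print))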