# Source: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/xgboost_iris_classification.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import logging
from typing import Any
from typing import Callable
from typing import Iterable
from typing import List
from typing import Tuple
from typing import Union

import numpy
import pandas
import scipy
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

import apache_beam as beam
import datatable
import xgboost
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerDatatable
from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerNumpy
from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerPandas
from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerSciPy
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult


class PostProcessor(beam.DoFn):
  """Format a keyed PredictionResult as a line of text.

  Each input element is a ``(key, PredictionResult)`` pair. The key is the
  integer that ``load_sklearn_iris_test_data`` attached to the batch (the
  sample index when splitting, ``0`` otherwise); it is not the true label.
  Yields a comma separated string with the key and the predicted value.
  """
  def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]:
    key, prediction_result = element
    prediction = prediction_result.inference
    yield '{},{}'.format(key, prediction)


def parse_known_args(argv):
  """Parses args for the workflow.

  Args:
    argv: Command line arguments; ``None`` lets argparse read ``sys.argv``.

  Returns:
    A ``(known_args, pipeline_args)`` tuple as produced by
    ``ArgumentParser.parse_known_args``: the namespace of the arguments
    declared here and the list of remaining, unparsed arguments.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input_type',
      dest='input_type',
      required=True,
      choices=['numpy', 'pandas', 'scipy', 'datatable'],
      help='Datatype of the input data.')
  parser.add_argument(
      '--output',
      dest='output',
      required=True,
      help='Path to save output predictions.')
  parser.add_argument(
      '--model_state',
      dest='model_state',
      required=True,
      help='Path to the state of the XGBoost model loaded for Inference.')
  # Exactly one of the two flags must be given; both write to ``split``.
  group = parser.add_mutually_exclusive_group(required=True)
  group.add_argument(
      '--split',
      action='store_true',
      dest='split',
      help='Feed the test set to the pipeline as one batch per sample.')
  group.add_argument(
      '--no_split',
      action='store_false',
      dest='split',
      help='Feed the whole test set to the pipeline as a single batch.')
  return parser.parse_known_args(argv)


def load_sklearn_iris_test_data(
    data_type: Callable,
    split: bool = True,
    seed: int = 999) -> List[Tuple[int, Any]]:
  """
    Loads test data from the sklearn Iris dataset in a given format,
    either in a single or multiple batches.
    Args:
      data_type: Callable that converts a numpy array of samples into the
        desired container (e.g. ``numpy.array`` or ``pandas.DataFrame``).
      split: Split the dataset in different batches or return single batch.
      seed: Random state for splitting the train and test set.
    Returns:
      A list of ``(key, batch)`` tuples. With ``split=True`` there is one
      entry per test sample, keyed by its index, each holding a single row
      reshaped to ``(1, n_features)``. Otherwise the list has one entry,
      keyed ``0``, holding the whole test set.
  """
  dataset = load_iris()
  # Only the test features are needed; train data and labels are discarded.
  _, x_test, _, _ = train_test_split(
      dataset['data'], dataset['target'], test_size=.2, random_state=seed)

  if split:
    return [(index, data_type(sample.reshape(1, -1)))
            for index, sample in enumerate(x_test)]
  return [(0, data_type(x_test))]


def run(
    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
  """Builds and runs the XGBoost Iris inference pipeline.

  Args:
    argv: Command line arguments defined for this example.
    save_main_session: Used for internal testing.
    test_pipeline: Used for internal testing.

  Returns:
    The result of the pipeline run, after waiting for it to finish.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  # Maps --input_type to (input container constructor, model handler class).
  data_types = {
      'numpy': (numpy.array, XGBoostModelHandlerNumpy),
      'pandas': (pandas.DataFrame, XGBoostModelHandlerPandas),
      'scipy': (scipy.sparse.csr_matrix, XGBoostModelHandlerSciPy),
      'datatable': (datatable.Frame, XGBoostModelHandlerDatatable),
  }

  input_data_type, model_handler = data_types[known_args.input_type]

  # KeyedModelHandler lets the (key, batch) tuples flow through RunInference.
  xgboost_model_handler = KeyedModelHandler(
      model_handler(
          model_class=xgboost.XGBClassifier,
          model_state=known_args.model_state))

  input_data = load_sklearn_iris_test_data(
      data_type=input_data_type, split=known_args.split)

  pipeline = test_pipeline
  if not test_pipeline:
    pipeline = beam.Pipeline(options=pipeline_options)

  predictions = (
      pipeline
      | "ReadInputData" >> beam.Create(input_data)
      | "RunInference" >> RunInference(xgboost_model_handler)
      | "PostProcessOutputs" >> beam.ParDo(PostProcessor()))

  _ = predictions | "WriteOutput" >> beam.io.WriteToText(
      known_args.output, shard_name_template='', append_trailing_newlines=True)

  result = pipeline.run()
  result.wait_until_finish()
  return result


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  run()