github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/xgboost_iris_classification.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  import argparse
    19  import logging
    20  from typing import Callable
    21  from typing import Iterable
    22  from typing import List
    23  from typing import Tuple
    24  from typing import Union
    25  
    26  import numpy
    27  import pandas
    28  import scipy
    29  from sklearn.datasets import load_iris
    30  from sklearn.model_selection import train_test_split
    31  
    32  import apache_beam as beam
    33  import datatable
    34  import xgboost
    35  from apache_beam.ml.inference.base import KeyedModelHandler
    36  from apache_beam.ml.inference.base import PredictionResult
    37  from apache_beam.ml.inference.base import RunInference
    38  from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerDatatable
    39  from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerNumpy
    40  from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerPandas
    41  from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerSciPy
    42  from apache_beam.options.pipeline_options import PipelineOptions
    43  from apache_beam.options.pipeline_options import SetupOptions
    44  from apache_beam.runners.runner import PipelineResult
    45  
    46  
    47  class PostProcessor(beam.DoFn):
    48    """Process the PredictionResult to get the predicted label.
    49    Returns a comma separated string with true label and predicted label.
    50    """
    51    def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]:
    52      label, prediction_result = element
    53      prediction = prediction_result.inference
    54      yield '{},{}'.format(label, prediction)
    55  
    56  
    57  def parse_known_args(argv):
    58    """Parses args for the workflow."""
    59    parser = argparse.ArgumentParser()
    60    parser.add_argument(
    61        '--input_type',
    62        dest='input_type',
    63        required=True,
    64        choices=['numpy', 'pandas', 'scipy', 'datatable'],
    65        help='Datatype of the input data.')
    66    parser.add_argument(
    67        '--output',
    68        dest='output',
    69        required=True,
    70        help='Path to save output predictions.')
    71    parser.add_argument(
    72        '--model_state',
    73        dest='model_state',
    74        required=True,
    75        help='Path to the state of the XGBoost model loaded for Inference.')
    76    group = parser.add_mutually_exclusive_group(required=True)
    77    group.add_argument('--split', action='store_true', dest='split')
    78    group.add_argument('--no_split', action='store_false', dest='split')
    79    return parser.parse_known_args(argv)
    80  
    81  
    82  def load_sklearn_iris_test_data(
    83      data_type: Callable,
    84      split: bool = True,
    85      seed: int = 999) -> List[Union[numpy.array, pandas.DataFrame]]:
    86    """
    87      Loads test data from the sklearn Iris dataset in a given format,
    88      either in a single or multiple batches.
    89      Args:
    90        data_type: Datatype of the iris test dataset.
    91        split: Split the dataset in different batches or return single batch.
    92        seed: Random state for splitting the train and test set.
    93    """
    94    dataset = load_iris()
    95    _, x_test, _, _ = train_test_split(
    96        dataset['data'], dataset['target'], test_size=.2, random_state=seed)
    97  
    98    if split:
    99      return [(index, data_type(sample.reshape(1, -1))) for index,
   100              sample in enumerate(x_test)]
   101    return [(0, data_type(x_test))]
   102  
   103  
   104  def run(
   105      argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
   106    """
   107      Args:
   108        argv: Command line arguments defined for this example.
   109        save_main_session: Used for internal testing.
   110        test_pipeline: Used for internal testing.
   111    """
   112    known_args, pipeline_args = parse_known_args(argv)
   113    pipeline_options = PipelineOptions(pipeline_args)
   114    pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
   115  
   116    data_types = {
   117        'numpy': (numpy.array, XGBoostModelHandlerNumpy),
   118        'pandas': (pandas.DataFrame, XGBoostModelHandlerPandas),
   119        'scipy': (scipy.sparse.csr_matrix, XGBoostModelHandlerSciPy),
   120        'datatable': (datatable.Frame, XGBoostModelHandlerDatatable),
   121    }
   122  
   123    input_data_type, model_handler = data_types[known_args.input_type]
   124  
   125    xgboost_model_handler = KeyedModelHandler(
   126        model_handler(
   127            model_class=xgboost.XGBClassifier,
   128            model_state=known_args.model_state))
   129  
   130    input_data = load_sklearn_iris_test_data(
   131        data_type=input_data_type, split=known_args.split)
   132  
   133    pipeline = test_pipeline
   134    if not test_pipeline:
   135      pipeline = beam.Pipeline(options=pipeline_options)
   136  
   137    predictions = (
   138        pipeline
   139        | "ReadInputData" >> beam.Create(input_data)
   140        | "RunInference" >> RunInference(xgboost_model_handler)
   141        | "PostProcessOutputs" >> beam.ParDo(PostProcessor()))
   142  
   143    _ = predictions | "WriteOutput" >> beam.io.WriteToText(
   144        known_args.output, shard_name_template='', append_trailing_newlines=True)
   145  
   146    result = pipeline.run()
   147    result.wait_until_finish()
   148    return result
   149  
   150  
   151  if __name__ == '__main__':
   152    logging.getLogger().setLevel(logging.INFO)
   153    run()