github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/sklearn_mnist_classification.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A pipeline that uses the RunInference API to classify MNIST data.

The pipeline reads a text file of comma-separated integers, where the first
column is the true label and the remaining columns are the pixel values. Each
row is preprocessed, and a model trained on the MNIST dataset is used to run
inference. The pipeline writes each prediction to an output file so that users
can compare it against the true label.
"""
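
# Example invocation (a sketch; the paths below are hypothetical and assume a
# pickled scikit-learn model and a comma-separated MNIST text file, typically
# a label followed by 784 pixel values per row for 28x28 images):
#
#   python -m apache_beam.examples.inference.sklearn_mnist_classification \
#       --input /path/to/mnist_test.csv \
#       --output /path/to/predictions.txt \
#       --model_path /path/to/mnist_model.pickle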

import argparse
import logging
import os
from typing import Iterable
from typing import List
from typing import Tuple

import apache_beam as beam
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.sklearn_inference import ModelFileType
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult


def process_input(row: str) -> Tuple[int, List[int]]:
  """Parses one 'label,pixel,pixel,...' row into (label, [pixel, ...])."""
  data = row.split(',')
  label, pixels = int(data[0]), data[1:]
  pixels = [int(pixel) for pixel in pixels]
  return label, pixels

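# A toy sanity check (illustrative values only, not real MNIST pixels):
# process_input('5,0,128,255') returns (5, [0, 128, 255]).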

class PostProcessor(beam.DoFn):
  """Processes the PredictionResult to get the predicted label.
  Yields a comma-separated string with the true label and the predicted label.
  """
  def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]:
    label, prediction_result = element
    prediction = prediction_result.inference
    yield '{},{}'.format(label, prediction)
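    # For example, an element (7, PredictionResult(example=..., inference=7))
    # yields the string '7,7' (illustrative values; PredictionResult carries
    # the model input as `example` and the model output as `inference`).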


def parse_known_args(argv):
  """Parses args for the workflow."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      dest='input',
      required=True,
      help='Text file with comma-separated int values.')
  parser.add_argument(
      '--output',
      dest='output',
      required=True,
      help='Path to save output predictions.')
  parser.add_argument(
      '--model_path',
      dest='model_path',
      required=True,
      help='Path to load the scikit-learn model for inference.')
  return parser.parse_known_args(argv)


def run(
    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
  """
  Args:
    argv: Command line arguments defined for this example.
    save_main_session: Used for internal testing.
    test_pipeline: Used for internal testing.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
  requirements_dir = os.path.dirname(os.path.realpath(__file__))
  # Pin to the version that the model was trained on.
  # scikit-learn doesn't guarantee compatibility between versions.
  pipeline_options.view_as(
      SetupOptions
  ).requirements_file = f'{requirements_dir}/sklearn_examples_requirements.txt'
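  # (sklearn_examples_requirements.txt is expected to sit next to this script;
  # it pins the scikit-learn version the example model was trained with.)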

  # In this example we pass keyed inputs to the RunInference transform, so we
  # wrap SklearnModelHandlerNumpy in a KeyedModelHandler.
  model_loader = KeyedModelHandler(
      SklearnModelHandlerNumpy(
          model_file_type=ModelFileType.PICKLE,
          model_uri=known_args.model_path))
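  # With keyed inputs, each element is a (key, example) tuple; KeyedModelHandler
  # forwards the example to the wrapped handler and re-attaches the key, so
  # RunInference emits (key, PredictionResult) pairs. Here the key is the true
  # label, which lets PostProcessor pair it with the prediction downstream.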

  pipeline = test_pipeline
  if not test_pipeline:
    pipeline = beam.Pipeline(options=pipeline_options)

  label_pixel_tuple = (
      pipeline
      | "ReadFromInput" >> beam.io.ReadFromText(known_args.input)
      | "PreProcessInputs" >> beam.Map(process_input))

  predictions = (
      label_pixel_tuple
      | "RunInference" >> RunInference(model_loader)
      | "PostProcessOutputs" >> beam.ParDo(PostProcessor()))

  _ = predictions | "WriteOutput" >> beam.io.WriteToText(
      known_args.output, shard_name_template='', append_trailing_newlines=True)
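  # Each output line is '<true_label>,<predicted_label>' (e.g. '7,7' for a
  # correct prediction; values illustrative). The empty shard_name_template
  # drops the usual '-SSSSS-of-NNNNN' shard suffix from the output filename.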

  result = pipeline.run()
  result.wait_until_finish()
  return result


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  run()