github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/tensorflow_mnist_classification.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  import argparse
    19  import logging
    20  from typing import Iterable
    21  from typing import Tuple
    22  
    23  import numpy
    24  
    25  import apache_beam as beam
    26  from apache_beam.ml.inference.base import KeyedModelHandler
    27  from apache_beam.ml.inference.base import PredictionResult
    28  from apache_beam.ml.inference.base import RunInference
    29  from apache_beam.ml.inference.tensorflow_inference import ModelType
    30  from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy
    31  from apache_beam.options.pipeline_options import PipelineOptions
    32  from apache_beam.options.pipeline_options import SetupOptions
    33  from apache_beam.runners.runner import PipelineResult
    34  
    35  
    36  def process_input(row: str) -> Tuple[int, numpy.ndarray]:
    37    data = row.split(',')
    38    label, pixels = int(data[0]), data[1:]
    39    pixels = [int(pixel) for pixel in pixels]
    40    # the trained model accepts the input of shape 28x28
    41    pixels = numpy.array(pixels).reshape((28, 28, 1))
    42    return label, pixels
    43  
    44  
    45  class PostProcessor(beam.DoFn):
    46    """Process the PredictionResult to get the predicted label.
    47    Returns a comma separated string with true label and predicted label.
    48    """
    49    def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]:
    50      label, prediction_result = element
    51      prediction = numpy.argmax(prediction_result.inference, axis=0)
    52      yield '{},{}'.format(label, prediction)
    53  
    54  
    55  def parse_known_args(argv):
    56    """Parses args for the workflow."""
    57    parser = argparse.ArgumentParser()
    58    parser.add_argument(
    59        '--input',
    60        dest='input',
    61        required=True,
    62        help='text file with comma separated int values.')
    63    parser.add_argument(
    64        '--output',
    65        dest='output',
    66        required=True,
    67        help='Path to save output predictions.')
    68    parser.add_argument(
    69        '--model_path',
    70        dest='model_path',
    71        required=True,
    72        help='Path to load the Tensorflow model for Inference.')
    73    return parser.parse_known_args(argv)
    74  
    75  
    76  def run(
    77      argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
    78    """
    79    Args:
    80      argv: Command line arguments defined for this example.
    81      save_main_session: Used for internal testing.
    82      test_pipeline: Used for internal testing.
    83    """
    84    known_args, pipeline_args = parse_known_args(argv)
    85    pipeline_options = PipelineOptions(pipeline_args)
    86    pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
    87  
    88    # In this example we pass keyed inputs to RunInference transform.
    89    # Therefore, we use KeyedModelHandler wrapper over TFModelHandlerNumpy.
    90    model_loader = KeyedModelHandler(
    91        TFModelHandlerNumpy(
    92            model_uri=known_args.model_path, model_type=ModelType.SAVED_MODEL))
    93  
    94    pipeline = test_pipeline
    95    if not test_pipeline:
    96      pipeline = beam.Pipeline(options=pipeline_options)
    97  
    98    label_pixel_tuple = (
    99        pipeline
   100        | "ReadFromInput" >> beam.io.ReadFromText(known_args.input)
   101        | "PreProcessInputs" >> beam.Map(process_input))
   102  
   103    predictions = (
   104        label_pixel_tuple
   105        | "RunInference" >> RunInference(model_loader)
   106        | "PostProcessOutputs" >> beam.ParDo(PostProcessor()))
   107  
   108    _ = predictions | "WriteOutput" >> beam.io.WriteToText(
   109        known_args.output, shard_name_template='', append_trailing_newlines=True)
   110  
   111    result = pipeline.run()
   112    result.wait_until_finish()
   113    return result
   114  
   115  
   116  if __name__ == '__main__':
   117    logging.getLogger().setLevel(logging.INFO)
   118    run()