github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  import argparse
    19  import logging
    20  from typing import Iterable
    21  from typing import Iterator
    22  
    23  import numpy
    24  
    25  import apache_beam as beam
    26  import tensorflow as tf
    27  from apache_beam.ml.inference.base import PredictionResult
    28  from apache_beam.ml.inference.base import RunInference
    29  from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
    30  from apache_beam.options.pipeline_options import PipelineOptions
    31  from apache_beam.options.pipeline_options import SetupOptions
    32  from apache_beam.runners.runner import PipelineResult
    33  from PIL import Image
    34  
    35  
    36  class PostProcessor(beam.DoFn):
    37    """Process the PredictionResult to get the predicted label.
    38    Returns predicted label.
    39    """
    40    def setup(self):
    41      labels_path = tf.keras.utils.get_file(
    42          'ImageNetLabels.txt',
    43          'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'  # pylint: disable=line-too-long
    44      )
    45      self._imagenet_labels = numpy.array(open(labels_path).read().splitlines())
    46  
    47    def process(self, element: PredictionResult) -> Iterable[str]:
    48      predicted_class = numpy.argmax(element.inference, axis=-1)
    49      predicted_class_name = self._imagenet_labels[predicted_class]
    50      yield predicted_class_name.title()
    51  
    52  
    53  def parse_known_args(argv):
    54    """Parses args for the workflow."""
    55    parser = argparse.ArgumentParser()
    56    parser.add_argument(
    57        '--input',
    58        dest='input',
    59        required=True,
    60        help='Path to the text file containing image names.')
    61    parser.add_argument(
    62        '--output',
    63        dest='output',
    64        required=True,
    65        help='Path to save output predictions.')
    66    parser.add_argument(
    67        '--model_path',
    68        dest='model_path',
    69        required=True,
    70        help='Path to load the Tensorflow model for Inference.')
    71    parser.add_argument(
    72        '--image_dir', help='Path to the directory where images are stored.')
    73    return parser.parse_known_args(argv)
    74  
    75  
    76  def filter_empty_lines(text: str) -> Iterator[str]:
    77    if len(text.strip()) > 0:
    78      yield text
    79  
    80  
    81  def read_image(image_name, image_dir):
    82    img = tf.keras.utils.get_file(image_name, image_dir + image_name)
    83    img = Image.open(img).resize((224, 224))
    84    img = numpy.array(img) / 255.0
    85    img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)
    86    return img_tensor
    87  
    88  
    89  def run(
    90      argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
    91    """
    92    Args:
    93      argv: Command line arguments defined for this example.
    94      save_main_session: Used for internal testing.
    95      test_pipeline: Used for internal testing.
    96    """
    97    known_args, pipeline_args = parse_known_args(argv)
    98    pipeline_options = PipelineOptions(pipeline_args)
    99    pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
   100  
   101    # In this example we will use the TensorflowHub model URL.
   102    model_loader = TFModelHandlerTensor(
   103        model_uri=known_args.model_path).with_preprocess_fn(
   104            lambda image_name: read_image(image_name, known_args.image_dir))
   105  
   106    pipeline = test_pipeline
   107    if not test_pipeline:
   108      pipeline = beam.Pipeline(options=pipeline_options)
   109  
   110    image = (
   111        pipeline
   112        | 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
   113        | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines))
   114  
   115    predictions = (
   116        image
   117        | "RunInference" >> RunInference(model_loader)
   118        | "PostProcessOutputs" >> beam.ParDo(PostProcessor()))
   119  
   120    _ = predictions | "WriteOutput" >> beam.io.WriteToText(
   121        known_args.output, shard_name_template='', append_trailing_newlines=True)
   122  
   123    result = pipeline.run()
   124    result.wait_until_finish()
   125    return result
   126  
   127  
   128  if __name__ == '__main__':
   129    logging.getLogger().setLevel(logging.INFO)
   130    run()