github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/pytorch_image_classification.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """A pipeline that uses RunInference API to perform image classification."""
    19  
    20  import argparse
    21  import io
    22  import logging
    23  import os
    24  from typing import Iterator
    25  from typing import Optional
    26  from typing import Tuple
    27  
    28  import apache_beam as beam
    29  import torch
    30  from apache_beam.io.filesystems import FileSystems
    31  from apache_beam.ml.inference.base import KeyedModelHandler
    32  from apache_beam.ml.inference.base import PredictionResult
    33  from apache_beam.ml.inference.base import RunInference
    34  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
    35  from apache_beam.options.pipeline_options import PipelineOptions
    36  from apache_beam.options.pipeline_options import SetupOptions
    37  from apache_beam.runners.runner import PipelineResult
    38  from PIL import Image
    39  from torchvision import models
    40  from torchvision import transforms
    41  
    42  
    43  def read_image(image_file_name: str,
    44                 path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
    45    if path_to_dir is not None:
    46      image_file_name = os.path.join(path_to_dir, image_file_name)
    47    with FileSystems().open(image_file_name, 'r') as file:
    48      data = Image.open(io.BytesIO(file.read())).convert('RGB')
    49      return image_file_name, data
    50  
    51  
    52  def preprocess_image(data: Image.Image) -> torch.Tensor:
    53    image_size = (224, 224)
    54    # Pre-trained PyTorch models expect input images normalized with the
    55    # below values (see: https://pytorch.org/vision/stable/models.html)
    56    normalize = transforms.Normalize(
    57        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    58    transform = transforms.Compose([
    59        transforms.Resize(image_size),
    60        transforms.ToTensor(),
    61        normalize,
    62    ])
    63    return transform(data)
    64  
    65  
    66  def filter_empty_lines(text: str) -> Iterator[str]:
    67    if len(text.strip()) > 0:
    68      yield text
    69  
    70  
    71  def parse_known_args(argv):
    72    """Parses args for the workflow."""
    73    parser = argparse.ArgumentParser()
    74    parser.add_argument(
    75        '--input',
    76        dest='input',
    77        required=True,
    78        help='Path to the text file containing image names.')
    79    parser.add_argument(
    80        '--output',
    81        dest='output',
    82        required=True,
    83        help='Path where to save output predictions.'
    84        ' text file.')
    85    parser.add_argument(
    86        '--model_state_dict_path',
    87        dest='model_state_dict_path',
    88        required=True,
    89        help="Path to the model's state_dict.")
    90    parser.add_argument(
    91        '--images_dir',
    92        default=None,
    93        help='Path to the directory where images are stored.'
    94        'Not required if image names in the input file have absolute path.')
    95    return parser.parse_known_args(argv)
    96  
    97  
    98  def run(
    99      argv=None,
   100      model_class=None,
   101      model_params=None,
   102      save_main_session=True,
   103      device='CPU',
   104      test_pipeline=None) -> PipelineResult:
   105    """
   106    Args:
   107      argv: Command line arguments defined for this example.
   108      model_class: Reference to the class definition of the model.
   109      model_params: Parameters passed to the constructor of the model_class.
   110                    These will be used to instantiate the model object in the
   111                    RunInference API.
   112      save_main_session: Used for internal testing.
   113      device: Device to be used on the Runner. Choices are (CPU, GPU).
   114      test_pipeline: Used for internal testing.
   115    """
   116    known_args, pipeline_args = parse_known_args(argv)
   117    pipeline_options = PipelineOptions(pipeline_args)
   118    pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
   119  
   120    if not model_class:
   121      # default model class will be mobilenet with pretrained weights.
   122      model_class = models.mobilenet_v2
   123      model_params = {'num_classes': 1000}
   124  
   125    def preprocess(image_name: str) -> Tuple[str, torch.Tensor]:
   126      image_name, image = read_image(
   127        image_file_name=image_name,
   128        path_to_dir=known_args.images_dir)
   129      return (image_name, preprocess_image(image))
   130  
   131    def postprocess(element: Tuple[str, PredictionResult]) -> str:
   132      filename, prediction_result = element
   133      prediction = torch.argmax(prediction_result.inference, dim=0)
   134      return filename + ',' + str(prediction.item())
   135  
   136    # In this example we pass keyed inputs to RunInference transform.
   137    # Therefore, we use KeyedModelHandler wrapper over PytorchModelHandler.
   138    model_handler = KeyedModelHandler(
   139        PytorchModelHandlerTensor(
   140            state_dict_path=known_args.model_state_dict_path,
   141            model_class=model_class,
   142            model_params=model_params,
   143            device=device,
   144            min_batch_size=10,
   145            max_batch_size=100)).with_preprocess_fn(
   146                preprocess).with_postprocess_fn(postprocess)
   147  
   148    pipeline = test_pipeline
   149    if not test_pipeline:
   150      pipeline = beam.Pipeline(options=pipeline_options)
   151  
   152    filename_value_pair = (
   153        pipeline
   154        | 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
   155        | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines))
   156    predictions = (
   157        filename_value_pair
   158        | 'PyTorchRunInference' >> RunInference(model_handler))
   159  
   160    predictions | "WriteOutputToGCS" >> beam.io.WriteToText( # pylint: disable=expression-not-assigned
   161      known_args.output,
   162      shard_name_template='',
   163      append_trailing_newlines=True)
   164  
   165    result = pipeline.run()
   166    result.wait_until_finish()
   167    return result
   168  
   169  
   170  if __name__ == '__main__':
   171    logging.getLogger().setLevel(logging.INFO)
   172    run()