github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/pytorch_image_segmentation.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A pipeline that uses RunInference API to perform image segmentation."""

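# Example invocation (a sketch only: the gs:// paths below are illustrative,
# and a saved state_dict for maskrcnn_resnet50_fpn must already exist at
# --model_state_dict_path):
#
#   python -m apache_beam.examples.inference.pytorch_image_segmentation \
#       --input gs://my-bucket/image_file_names.txt \
#       --output gs://my-bucket/predictions.txt \
#       --model_state_dict_path gs://my-bucket/maskrcnn_resnet50_fpn.pth \
#       --images_dir gs://my-bucket/images
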
import argparse
import io
import logging
import os
from typing import Iterable
from typing import Iterator
from typing import Optional
from typing import Tuple

import apache_beam as beam
import torch
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult
from PIL import Image
from torchvision import transforms
from torchvision.models.detection import maskrcnn_resnet50_fpn

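# COCO instance segmentation class names, indexed by the class id emitted by
# torchvision's Mask R-CNN models. 'N/A' marks ids that are unused in COCO,
# keeping the list aligned with the model's 91 output classes.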
COCO_INSTANCE_CLASSES = [
    '__background__',
    'person',
    'bicycle',
    'car',
    'motorcycle',
    'airplane',
    'bus',
    'train',
    'truck',
    'boat',
    'traffic light',
    'fire hydrant',
    'N/A',
    'stop sign',
    'parking meter',
    'bench',
    'bird',
    'cat',
    'dog',
    'horse',
    'sheep',
    'cow',
    'elephant',
    'bear',
    'zebra',
    'giraffe',
    'N/A',
    'backpack',
    'umbrella',
    'N/A',
    'N/A',
    'handbag',
    'tie',
    'suitcase',
    'frisbee',
    'skis',
    'snowboard',
    'sports ball',
    'kite',
    'baseball bat',
    'baseball glove',
    'skateboard',
    'surfboard',
    'tennis racket',
    'bottle',
    'N/A',
    'wine glass',
    'cup',
    'fork',
    'knife',
    'spoon',
    'bowl',
    'banana',
    'apple',
    'sandwich',
    'orange',
    'broccoli',
    'carrot',
    'hot dog',
    'pizza',
    'donut',
    'cake',
    'chair',
    'couch',
    'potted plant',
    'bed',
    'N/A',
    'dining table',
    'N/A',
    'N/A',
    'toilet',
    'N/A',
    'tv',
    'laptop',
    'mouse',
    'remote',
    'keyboard',
    'cell phone',
    'microwave',
    'oven',
    'toaster',
    'sink',
    'refrigerator',
    'N/A',
    'book',
    'clock',
    'vase',
    'scissors',
    'teddy bear',
    'hair drier',
    'toothbrush'
]

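# Maps an integer class id from the model output to its COCO class name.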
CLASS_ID_TO_NAME = dict(enumerate(COCO_INSTANCE_CLASSES))


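# Reads a single image, optionally joining the file name with path_to_dir,
# and returns a (file name, RGB PIL image) pair.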
def read_image(image_file_name: str,
               path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
  if path_to_dir is not None:
    image_file_name = os.path.join(path_to_dir, image_file_name)
  with FileSystems().open(image_file_name, 'r') as file:
    data = Image.open(io.BytesIO(file.read())).convert('RGB')
    return image_file_name, data


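# Resizes the image to 224x224 and converts it to a float tensor with values
# in [0, 1], the input format expected by the torchvision model.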
def preprocess_image(data: Image.Image) -> torch.Tensor:
  image_size = (224, 224)
  transform = transforms.Compose([
      transforms.Resize(image_size),
      transforms.ToTensor(),
  ])
  return transform(data)


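# Drops blank lines from the input file so that only real image names reach
# the read step.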
def filter_empty_lines(text: str) -> Iterator[str]:
  if len(text.strip()) > 0:
    yield text


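# Formats each prediction as '<file name>;<list of class names>', producing
# lines such as "img.jpg;['person', 'dog']" (file name and classes here are
# illustrative only).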
class PostProcessor(beam.DoFn):
  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
    filename, prediction_result = element
    prediction_labels = prediction_result.inference['labels']
    classes = [CLASS_ID_TO_NAME[label.item()] for label in prediction_labels]
    yield filename + ';' + str(classes)


def parse_known_args(argv):
  """Parses args for the workflow."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      dest='input',
      required=True,
      help='Path to the text file containing image names.')
  parser.add_argument(
      '--output',
      dest='output',
      required=True,
      help='Path to the text file where output predictions are saved.')
  parser.add_argument(
      '--model_state_dict_path',
      dest='model_state_dict_path',
      required=True,
      help="Path to the model's state_dict. "
      "By default a state_dict for maskrcnn_resnet50_fpn is expected.")
  parser.add_argument(
      '--images_dir',
      help='Path to the directory where images are stored. '
      'Not required if image names in the input file are absolute paths.')
  return parser.parse_known_args(argv)


def run(
    argv=None,
    model_class=None,
    model_params=None,
    save_main_session=True,
    test_pipeline=None) -> PipelineResult:
  """
  Args:
    argv: Command line arguments defined for this example.
    model_class: Reference to the class definition of the model.
                 If None, maskrcnn_resnet50_fpn is used as the default.
    model_params: Parameters passed to the constructor of the model_class.
                  These will be used to instantiate the model object in the
                  RunInference API.
    save_main_session: Used for internal testing.
    test_pipeline: Used for internal testing.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  if not model_class:
    model_class = maskrcnn_resnet50_fpn
    model_params = {'num_classes': 91}

  model_handler = PytorchModelHandlerTensor(
      state_dict_path=known_args.model_state_dict_path,
      model_class=model_class,
      model_params=model_params)

  pipeline = test_pipeline
  if not test_pipeline:
    pipeline = beam.Pipeline(options=pipeline_options)

  filename_value_pair = (
      pipeline
      | 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
      | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines)
      | 'ReadImageData' >> beam.Map(
          lambda image_name: read_image(
              image_file_name=image_name, path_to_dir=known_args.images_dir))
      | 'PreprocessImages' >> beam.MapTuple(
          lambda file_name, data: (file_name, preprocess_image(data))))
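  # KeyedModelHandler keeps each tensor paired with its file name, so every
  # prediction can be joined back to the originating image downstream.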
  predictions = (
      filename_value_pair
      | 'PyTorchRunInference' >> RunInference(KeyedModelHandler(model_handler))
      | 'ProcessOutput' >> beam.ParDo(PostProcessor()))

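  # An empty shard_name_template disables the shard suffix, so the predictions
  # are written to a single file at known_args.output.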
  _ = predictions | 'WriteOutput' >> beam.io.WriteToText(
      known_args.output, shard_name_template='', append_trailing_newlines=True)

  result = pipeline.run()
  result.wait_until_finish()
  return result


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  run()