github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/tensorrt_object_detection.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

    18  """A pipeline that uses RunInference API to perform object detection with
    19  TensorRT.
    20  """

import argparse
import io
import os
from typing import Iterable
from typing import Optional
from typing import Tuple

import numpy as np

import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorrt_inference import TensorRTEngineHandlerNumPy  # pylint: disable=line-too-long
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from PIL import Image

COCO_OBJ_DET_CLASSES = [
    'person',
    'bicycle',
    'car',
    'motorcycle',
    'airplane',
    'bus',
    'train',
    'truck',
    'boat',
    'traffic light',
    'fire hydrant',
    'street sign',
    'stop sign',
    'parking meter',
    'bench',
    'bird',
    'cat',
    'dog',
    'horse',
    'sheep',
    'cow',
    'elephant',
    'bear',
    'zebra',
    'giraffe',
    'hat',
    'backpack',
    'umbrella',
    'shoe',
    'eye glasses',
    'handbag',
    'tie',
    'suitcase',
    'frisbee',
    'skis',
    'snowboard',
    'sports ball',
    'kite',
    'baseball bat',
    'baseball glove',
    'skateboard',
    'surfboard',
    'tennis racket',
    'bottle',
    'plate',
    'wine glass',
    'cup',
    'fork',
    'knife',
    'spoon',
    'bowl',
    'banana',
    'apple',
    'sandwich',
    'orange',
    'broccoli',
    'carrot',
    'hot dog',
    'pizza',
    'donut',
    'cake',
    'chair',
    'couch',
    'potted plant',
    'bed',
    'mirror',
    'dining table',
    'window',
    'desk',
    'toilet',
    'door',
    'tv',
    'laptop',
    'mouse',
    'remote',
    'keyboard',
    'cell phone',
    'microwave',
    'oven',
    'toaster',
    'sink',
    'refrigerator',
    'blender',
    'book',
    'clock',
    'vase',
    'scissors',
    'teddy bear',
    'hair drier',
    'toothbrush',
    'hair brush',
]


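# The detection model emits normalized box coordinates, so each image's
# original (width, height) is attached to its key here and later used by
# PostProcessor to rescale the boxes to pixel coordinates.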
def attach_im_size_to_key(
    data: Tuple[str, Image.Image]) -> Tuple[Tuple[str, int, int], Image.Image]:
  filename, image = data
  width, height = image.size
  return ((filename, width, height), image)


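# Beam's FileSystems abstraction resolves the path scheme, so image files can
# be read from local disk or from a distributed filesystem such as GCS.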
def read_image(image_file_name: str,
               path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
  if path_to_dir is not None:
    image_file_name = os.path.join(path_to_dir, image_file_name)
  with FileSystems().open(image_file_name, 'r') as file:
    data = Image.open(io.BytesIO(file.read())).convert('RGB')
    return image_file_name, data


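# Resizes each image to the fixed spatial dimensions the TensorRT engine is
# assumed to have been built with, and adds a leading batch dimension so the
# resulting array has shape (1, 300, 300, 3).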
def preprocess_image(image: Image.Image) -> np.ndarray:
  ssd_mobilenet_v2_320x320_input_dims = (300, 300)
  image = image.resize(
      ssd_mobilenet_v2_320x320_input_dims, resample=Image.Resampling.BILINEAR)
  image = np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)
  return image


class PostProcessor(beam.DoFn):
  """Processes a PredictionResult that consists of the number of detections
  per image, box coordinates, scores and classes.

  We loop over all detections to organize the attributes on a per-detection
  basis. Box coordinates are normalized, so they are scaled back to the
  original image dimensions. The score is a floating point number giving the
  confidence for a detection, and the class is an integer index that is mapped
  to a human-readable label via COCO_OBJ_DET_CLASSES.
  """
  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
    key, prediction_result = element
    filename, im_width, im_height = key
    num_detections = prediction_result.inference[0]
    boxes = prediction_result.inference[1]
    scores = prediction_result.inference[2]
    classes = prediction_result.inference[3]
    detections = []
    for i in range(int(num_detections[0])):
      detections.append({
          'ymin': str(boxes[i][0] * im_height),
          'xmin': str(boxes[i][1] * im_width),
          'ymax': str(boxes[i][2] * im_height),
          'xmax': str(boxes[i][3] * im_width),
          'score': str(scores[i]),
          'class': COCO_OBJ_DET_CLASSES[int(classes[i])]
      })
    yield filename + ',' + str(detections)


def parse_known_args(argv):
  """Parses args for the workflow."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      dest='input',
      required=True,
      help='Path to the text file containing image names.')
  parser.add_argument(
      '--output',
      dest='output',
      required=True,
      help='Path where to save the output predictions text file.')
  parser.add_argument(
      '--engine_path',
      dest='engine_path',
      required=True,
      help='Path to the pre-built TFOD ssd_mobilenet_v2_320x320_coco17_tpu-8 '
      'TensorRT engine.')
  parser.add_argument(
      '--images_dir',
      default=None,
      help='Path to the directory where images are stored. '
      'Not required if image names in the input file are absolute paths.')
  return parser.parse_known_args(argv)


def run(argv=None, save_main_session=True):
  """
  Args:
    argv: Command line arguments defined for this example.
    save_main_session: Whether to pickle the main session state so that
      global imports are available on remote workers.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

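  # KeyedModelHandler wraps the TensorRT handler so that the
  # (filename, width, height) key attached upstream is carried through
  # RunInference and returned alongside each PredictionResult. Batch size is
  # pinned to 1, i.e. one image per inference call.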
  engine_handler = KeyedModelHandler(
      TensorRTEngineHandlerNumPy(
          min_batch_size=1,
          max_batch_size=1,
          engine_path=known_args.engine_path))

  with beam.Pipeline(options=pipeline_options) as p:
    filename_value_pair = (
        p
        | 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
        | 'ReadImageData' >> beam.Map(
            lambda image_name: read_image(
                image_file_name=image_name, path_to_dir=known_args.images_dir))
        | 'AttachImageSizeToKey' >> beam.Map(attach_im_size_to_key)
        | 'PreprocessImages' >> beam.MapTuple(
            lambda file_name, data: (file_name, preprocess_image(data))))
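    # RunInference feeds each keyed NumPy array to the TensorRT engine and
    # emits (key, PredictionResult) pairs; PostProcessor formats them into one
    # output line per image.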
    predictions = (
        filename_value_pair
        | 'TensorRTRunInference' >> RunInference(engine_handler)
        | 'ProcessOutput' >> beam.ParDo(PostProcessor()))

    _ = (
        predictions | "WriteOutputToGCS" >> beam.io.WriteToText(
            known_args.output,
            shard_name_template='',
            append_trailing_newlines=True))


if __name__ == '__main__':
  run()