github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/tensorrt_object_detection.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A pipeline that uses RunInference API to perform object detection with
TensorRT.
"""

import argparse
import io
import os
from typing import Iterable
from typing import Optional
from typing import Tuple

import numpy as np

import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorrt_inference import TensorRTEngineHandlerNumPy  # pylint: disable=line-too-long
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from PIL import Image

COCO_OBJ_DET_CLASSES = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
    'truck', 'boat', 'traffic light', 'fire hydrant', 'street sign',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack',
    'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'plate', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana',
    'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror',
    'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop',
    'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase',
    'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'hair brush',
]


def attach_im_size_to_key(
    data: Tuple[str, Image.Image]) -> Tuple[Tuple[str, int, int], Image.Image]:
  filename, image = data
  width, height = image.size
  return ((filename, width, height), image)


def read_image(image_file_name: str,
               path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
  if path_to_dir is not None:
    image_file_name = os.path.join(path_to_dir, image_file_name)
  with FileSystems().open(image_file_name, 'r') as file:
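    # Decode the raw bytes with PIL and convert to RGB so that grayscale or
    # RGBA inputs all reach preprocessing as 3-channel images.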
    data = Image.open(io.BytesIO(file.read())).convert('RGB')
  return image_file_name, data


def preprocess_image(image: Image.Image) -> np.ndarray:
  ssd_mobilenet_v2_320x320_input_dims = (300, 300)
  image = image.resize(
      ssd_mobilenet_v2_320x320_input_dims, resample=Image.Resampling.BILINEAR)
  image = np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)
  return image


class PostProcessor(beam.DoFn):
  """Processes the PredictionResult, which consists of the number of
  detections per image, box coordinates, scores and classes.

  We loop over all detections to organize attributes on a per-detection
  basis. Box coordinates are normalized, hence we have to scale them
  according to the original image dimensions. Score is a floating point
  number that gives the detection confidence for a particular object. Class
  is an integer that we map to a human-readable label using
  COCO_OBJ_DET_CLASSES as reference.
  """
  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
    key, prediction_result = element
    filename, im_width, im_height = key
    num_detections = prediction_result.inference[0]
    boxes = prediction_result.inference[1]
    scores = prediction_result.inference[2]
    classes = prediction_result.inference[3]
    detections = []
    for i in range(int(num_detections[0])):
      detections.append({
          'ymin': str(boxes[i][0] * im_height),
          'xmin': str(boxes[i][1] * im_width),
          'ymax': str(boxes[i][2] * im_height),
          'xmax': str(boxes[i][3] * im_width),
          'score': str(scores[i]),
          'class': COCO_OBJ_DET_CLASSES[int(classes[i])]
      })
    yield filename + ',' + str(detections)


def parse_known_args(argv):
  """Parses args for the workflow."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      dest='input',
      required=True,
      help='Path to the text file containing image names.')
  parser.add_argument(
      '--output',
      dest='output',
      required=True,
      help='Path where to save the output predictions text file.')
  parser.add_argument(
      '--engine_path',
      dest='engine_path',
      required=True,
      help='Path to the pre-built TFOD ssd_mobilenet_v2_320x320_coco17_tpu-8 '
      'TensorRT engine.')
  parser.add_argument(
      '--images_dir',
      default=None,
      help='Path to the directory where images are stored. '
      'Not required if image names in the input file have absolute paths.')
  return parser.parse_known_args(argv)


def run(argv=None, save_main_session=True):
  """Builds and runs the object detection pipeline.

  Args:
    argv: Command line arguments defined for this example.
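    save_main_session: Whether to pickle the main session so that
      module-level imports and globals (e.g. COCO_OBJ_DET_CLASSES) are
      available on remote workers.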
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  engine_handler = KeyedModelHandler(
      TensorRTEngineHandlerNumPy(
          min_batch_size=1,
          max_batch_size=1,
          engine_path=known_args.engine_path))

  with beam.Pipeline(options=pipeline_options) as p:
    filename_value_pair = (
        p
        | 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
        | 'ReadImageData' >> beam.Map(
            lambda image_name: read_image(
                image_file_name=image_name, path_to_dir=known_args.images_dir))
        | 'AttachImageSizeToKey' >> beam.Map(attach_im_size_to_key)
        | 'PreprocessImages' >> beam.MapTuple(
            lambda file_name, data: (file_name, preprocess_image(data))))
    predictions = (
        filename_value_pair
        | 'TensorRTRunInference' >> RunInference(engine_handler)
        | 'ProcessOutput' >> beam.ParDo(PostProcessor()))

    _ = (
        predictions | "WriteOutputToGCS" >> beam.io.WriteToText(
            known_args.output,
            shard_name_template='',
            append_trailing_newlines=True))


if __name__ == '__main__':
  run()
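# Example invocation (a sketch: the bucket, file names and engine path below
# are placeholders, and runner-specific flags such as --runner or --project
# are omitted; only --input, --output, --engine_path and --images_dir are
# defined by this example):
#
#   python -m apache_beam.examples.inference.tensorrt_object_detection \
#       --input gs://<bucket>/image_file_names.txt \
#       --engine_path gs://<bucket>/ssd_mobilenet_v2_320x320_coco17_tpu-8.trt \
#       --images_dir gs://<bucket>/images \
#       --output gs://<bucket>/predictions.txt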