# Source: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/pytorch_image_segmentation.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A pipeline that uses RunInference API to perform image segmentation.

Each non-empty line of the ``--input`` text file names an image. The image is
read and converted to RGB, resized to 224x224 and turned into a tensor, then
run through a PyTorch model (``maskrcnn_resnet50_fpn`` by default). For every
image, one line ``<image path>;<list of predicted class names>`` is written to
``--output``.
"""

import argparse
import io
import logging
import os
from typing import Iterable
from typing import Iterator
from typing import Optional
from typing import Tuple

import apache_beam as beam
import torch
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult
from PIL import Image
from torchvision import transforms
from torchvision.models.detection import maskrcnn_resnet50_fpn

# Class names indexed by the label id the model predicts. The list has 91
# entries, matching the default ``num_classes=91`` used in ``run``; 'N/A'
# entries are unused ids.
COCO_INSTANCE_CLASSES = [
    '__background__',
    'person',
    'bicycle',
    'car',
    'motorcycle',
    'airplane',
    'bus',
    'train',
    'truck',
    'boat',
    'traffic light',
    'fire hydrant',
    'N/A',
    'stop sign',
    'parking meter',
    'bench',
    'bird',
    'cat',
    'dog',
    'horse',
    'sheep',
    'cow',
    'elephant',
    'bear',
    'zebra',
    'giraffe',
    'N/A',
    'backpack',
    'umbrella',
    'N/A',
    'N/A',
    'handbag',
    'tie',
    'suitcase',
    'frisbee',
    'skis',
    'snowboard',
    'sports ball',
    'kite',
    'baseball bat',
    'baseball glove',
    'skateboard',
    'surfboard',
    'tennis racket',
    'bottle',
    'N/A',
    'wine glass',
    'cup',
    'fork',
    'knife',
    'spoon',
    'bowl',
    'banana',
    'apple',
    'sandwich',
    'orange',
    'broccoli',
    'carrot',
    'hot dog',
    'pizza',
    'donut',
    'cake',
    'chair',
    'couch',
    'potted plant',
    'bed',
    'N/A',
    'dining table',
    'N/A',
    'N/A',
    'toilet',
    'N/A',
    'tv',
    'laptop',
    'mouse',
    'remote',
    'keyboard',
    'cell phone',
    'microwave',
    'oven',
    'toaster',
    'sink',
    'refrigerator',
    'N/A',
    'book',
    'clock',
    'vase',
    'scissors',
    'teddy bear',
    'hair drier',
    'toothbrush'
]

CLASS_ID_TO_NAME = dict(enumerate(COCO_INSTANCE_CLASSES))

# The preprocessing transform has no per-image state, so it is built once
# here rather than once per element inside preprocess_image.
_IMAGE_SIZE = (224, 224)
_PREPROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize(_IMAGE_SIZE),
    transforms.ToTensor(),
])


def read_image(image_file_name: str,
               path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
  """Reads an image file and returns it keyed by its path.

  Args:
    image_file_name: Name or path of the image file.
    path_to_dir: Optional directory joined in front of ``image_file_name``
      with ``os.path.join`` (an absolute ``image_file_name`` wins the join).

  Returns:
    A ``(path, image)`` tuple; ``path`` is the joined path that was read and
    ``image`` is the decoded image converted to RGB.
  """
  if path_to_dir is not None:
    image_file_name = os.path.join(path_to_dir, image_file_name)
  # FileSystems.open is a static method whose second positional parameter is
  # the MIME type, not a file mode, so no mode argument is passed.
  with FileSystems.open(image_file_name) as file:
    data = Image.open(io.BytesIO(file.read())).convert('RGB')
  return image_file_name, data


def preprocess_image(data: Image.Image) -> torch.Tensor:
  """Resizes an image to 224x224 and converts it to a tensor via ToTensor."""
  return _PREPROCESS_TRANSFORM(data)


def filter_empty_lines(text: str) -> Iterator[str]:
  """Yields ``text`` unchanged unless it is empty or only whitespace."""
  if len(text.strip()) > 0:
    yield text


class PostProcessor(beam.DoFn):
  """Formats a keyed prediction as ``'<filename>;<list of class names>'``."""
  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
    filename, prediction_result = element
    # The model output's 'labels' entry holds predicted class ids; each is a
    # scalar tensor, so .item() is used before looking up its name.
    prediction_labels = prediction_result.inference['labels']
    classes = [CLASS_ID_TO_NAME[label.item()] for label in prediction_labels]
    yield filename + ';' + str(classes)


def parse_known_args(argv):
  """Parses args for the workflow.

  Returns:
    ``(known_args, remaining_args)``; the remaining args are passed on to
    the Beam pipeline options.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      dest='input',
      required=True,
      help='Path to the text file containing image names.')
  parser.add_argument(
      '--output',
      dest='output',
      required=True,
      help='Path where to save the output predictions text file.')
  parser.add_argument(
      '--model_state_dict_path',
      dest='model_state_dict_path',
      required=True,
      help="Path to the model's state_dict. With the default model class "
      "this must be a maskrcnn_resnet50_fpn state_dict.")
  parser.add_argument(
      '--images_dir',
      help='Path to the directory where images are stored. '
      'Not required if image names in the input file have absolute path.')
  return parser.parse_known_args(argv)


def run(
    argv=None,
    model_class=None,
    model_params=None,
    save_main_session=True,
    test_pipeline=None) -> PipelineResult:
  """
  Args:
    argv: Command line arguments defined for this example.
    model_class: Reference to the class definition of the model.
      If None, maskrcnn_resnet50_fpn will be used as default.
    model_params: Parameters passed to the constructor of the model_class.
      These will be used to instantiate the model object in the
      RunInference API. If both this and model_class are None,
      ``{'num_classes': 91}`` is used.
    save_main_session: Used for internal testing.
    test_pipeline: Used for internal testing.

  Returns:
    The PipelineResult, after waiting for the pipeline to finish.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  if not model_class:
    model_class = maskrcnn_resnet50_fpn
    # Only fill in the default when the caller did not supply parameters;
    # previously an explicit model_params was silently discarded here.
    if model_params is None:
      model_params = {'num_classes': 91}

  model_handler = PytorchModelHandlerTensor(
      state_dict_path=known_args.model_state_dict_path,
      model_class=model_class,
      model_params=model_params)

  pipeline = test_pipeline
  if not test_pipeline:
    pipeline = beam.Pipeline(options=pipeline_options)

  filename_value_pair = (
      pipeline
      | 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
      | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines)
      # Each element (an image name) is passed as read_image's first
      # positional argument; the directory is bound as a keyword argument.
      | 'ReadImageData' >> beam.Map(
          read_image, path_to_dir=known_args.images_dir)
      | 'PreprocessImages' >> beam.MapTuple(
          lambda file_name, data: (file_name, preprocess_image(data))))
  predictions = (
      filename_value_pair
      # KeyedModelHandler keeps the filename key attached to each prediction.
      | 'PyTorchRunInference' >> RunInference(KeyedModelHandler(model_handler))
      | 'ProcessOutput' >> beam.ParDo(PostProcessor()))

  _ = predictions | "WriteOutput" >> beam.io.WriteToText(
      known_args.output, shard_name_template='', append_trailing_newlines=True)

  result = pipeline.run()
  result.wait_until_finish()
  return result


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  run()