# Source: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/pytorch_image_classification.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A pipeline that uses RunInference API to perform image classification."""

import argparse
import io
import logging
import os
from typing import Iterator
from typing import Optional
from typing import Tuple

import apache_beam as beam
import torch
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult
from PIL import Image
from torchvision import models
from torchvision import transforms


def read_image(
        image_file_name: str,
        path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
    """Read an image through Beam's FileSystems and decode it as RGB.

    Args:
        image_file_name: Name (or full path) of the image file.
        path_to_dir: Optional directory that is joined in front of
            ``image_file_name`` when it is not None.

    Returns:
        A ``(resolved_file_name, image)`` tuple, where ``image`` is the
        decoded picture converted to RGB mode.
    """
    if path_to_dir is not None:
        image_file_name = os.path.join(path_to_dir, image_file_name)
    # FileSystems.open is a static method whose second positional parameter
    # is the MIME type, not a file mode, so no mode string is passed here.
    with FileSystems.open(image_file_name) as file:
        # Buffer the bytes so PIL decodes from memory rather than from the
        # (possibly remote) file handle.
        data = Image.open(io.BytesIO(file.read())).convert('RGB')
    return image_file_name, data


def preprocess_image(data: Image.Image) -> torch.Tensor:
    """Resize, tensorize and normalize an image for the model.

    Args:
        data: The image to transform.

    Returns:
        The result of applying Resize((224, 224)), ToTensor and Normalize
        to ``data``.
    """
    image_size = (224, 224)
    # Pre-trained PyTorch models expect input images normalized with the
    # below values (see: https://pytorch.org/vision/stable/models.html)
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        normalize,
    ])
    return transform(data)


def filter_empty_lines(text: str) -> Iterator[str]:
    """Yield ``text`` unchanged unless it is empty or only whitespace."""
    if text.strip():
        yield text


def parse_known_args(argv):
    """Parses args for the workflow.

    Args:
        argv: Command line arguments to parse; ``None`` makes argparse fall
            back to ``sys.argv[1:]``.

    Returns:
        A ``(known_args, pipeline_args)`` tuple: the namespace of the options
        declared here and the list of remaining, unrecognized arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input',
        dest='input',
        required=True,
        help='Path to the text file containing image names.')
    parser.add_argument(
        '--output',
        dest='output',
        required=True,
        help='Path to the text file where output predictions are saved.')
    parser.add_argument(
        '--model_state_dict_path',
        dest='model_state_dict_path',
        required=True,
        help="Path to the model's state_dict.")
    parser.add_argument(
        '--images_dir',
        default=None,
        help='Path to the directory where images are stored. '
        'Not required if image names in the input file have absolute path.')
    return parser.parse_known_args(argv)


def run(
        argv=None,
        model_class=None,
        model_params=None,
        save_main_session=True,
        device='CPU',
        test_pipeline=None) -> PipelineResult:
    """Build and run the image classification pipeline.

    Args:
        argv: Command line arguments defined for this example.
        model_class: Reference to the class definition of the model.
            When falsy, ``models.mobilenet_v2`` is used and ``model_params``
            is replaced with ``{'num_classes': 1000}``.
        model_params: Parameters passed to the constructor of the
            model_class. These will be used to instantiate the model object
            in the RunInference API.
        save_main_session: Used for internal testing.
        device: Device to be used on the Runner. Choices are (CPU, GPU).
        test_pipeline: Used for internal testing.

    Returns:
        The pipeline result, returned after ``wait_until_finish()`` has been
        called on it.
    """
    known_args, pipeline_args = parse_known_args(argv)
    pipeline_options = PipelineOptions(pipeline_args)
    pipeline_options.view_as(
        SetupOptions).save_main_session = save_main_session

    if not model_class:
        # default model class will be mobilenet with pretrained weights.
        model_class = models.mobilenet_v2
        model_params = {'num_classes': 1000}

    def preprocess(image_name: str) -> Tuple[str, torch.Tensor]:
        """Load the named image and return (file name, model input)."""
        image_name, image = read_image(
            image_file_name=image_name,
            path_to_dir=known_args.images_dir)
        return (image_name, preprocess_image(image))

    def postprocess(element: Tuple[str, PredictionResult]) -> str:
        """Format a keyed prediction as '<file name>,<argmax index>'."""
        filename, prediction_result = element
        prediction = torch.argmax(prediction_result.inference, dim=0)
        return f'{filename},{prediction.item()}'

    # In this example we pass keyed inputs to RunInference transform.
    # Therefore, we use KeyedModelHandler wrapper over PytorchModelHandler.
    model_handler = KeyedModelHandler(
        PytorchModelHandlerTensor(
            state_dict_path=known_args.model_state_dict_path,
            model_class=model_class,
            model_params=model_params,
            device=device,
            min_batch_size=10,
            max_batch_size=100)).with_preprocess_fn(
                preprocess).with_postprocess_fn(postprocess)

    # Use the injected pipeline when one is supplied; otherwise create one
    # from the parsed pipeline options.
    pipeline = test_pipeline or beam.Pipeline(options=pipeline_options)

    filename_value_pair = (
        pipeline
        | 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
        | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines))
    predictions = (
        filename_value_pair
        | 'PyTorchRunInference' >> RunInference(model_handler))

    predictions | "WriteOutputToGCS" >> beam.io.WriteToText(  # pylint: disable=expression-not-assigned
        known_args.output,
        shard_name_template='',
        append_trailing_newlines=True)

    result = pipeline.run()
    result.wait_until_finish()
    return result


if __name__ == '__main__':
    logging.getLogger().setLevel(logging.INFO)
    run()