github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/tensorflow_mnist_classification.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import logging
from typing import Iterable
from typing import Tuple

import numpy

import apache_beam as beam
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorflow_inference import ModelType
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult


def process_input(row: str) -> Tuple[int, numpy.ndarray]:
  data = row.split(',')
  label, pixels = int(data[0]), data[1:]
  pixels = [int(pixel) for pixel in pixels]
  # the trained model accepts the input of shape 28x28
  pixels = numpy.array(pixels).reshape((28, 28, 1))
  return label, pixels


class PostProcessor(beam.DoFn):
  """Process the PredictionResult to get the predicted label.
  Returns a comma separated string with true label and predicted label.
  """
  def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]:
    label, prediction_result = element
    prediction = numpy.argmax(prediction_result.inference, axis=0)
    yield '{},{}'.format(label, prediction)


def parse_known_args(argv):
  """Parses args for the workflow."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      dest='input',
      required=True,
      help='text file with comma separated int values.')
  parser.add_argument(
      '--output',
      dest='output',
      required=True,
      help='Path to save output predictions.')
  parser.add_argument(
      '--model_path',
      dest='model_path',
      required=True,
      help='Path to load the Tensorflow model for Inference.')
  return parser.parse_known_args(argv)


def run(
    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
  """
  Args:
    argv: Command line arguments defined for this example.
    save_main_session: Used for internal testing.
    test_pipeline: Used for internal testing.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  # In this example we pass keyed inputs to RunInference transform.
  # Therefore, we use KeyedModelHandler wrapper over TFModelHandlerNumpy.
  model_loader = KeyedModelHandler(
      TFModelHandlerNumpy(
          model_uri=known_args.model_path, model_type=ModelType.SAVED_MODEL))

  pipeline = test_pipeline
  if not test_pipeline:
    pipeline = beam.Pipeline(options=pipeline_options)

  label_pixel_tuple = (
      pipeline
      | "ReadFromInput" >> beam.io.ReadFromText(known_args.input)
      | "PreProcessInputs" >> beam.Map(process_input))

  predictions = (
      label_pixel_tuple
      | "RunInference" >> RunInference(model_loader)
      | "PostProcessOutputs" >> beam.ParDo(PostProcessor()))

  _ = predictions | "WriteOutput" >> beam.io.WriteToText(
      known_args.output, shard_name_template='', append_trailing_newlines=True)

  result = pipeline.run()
  result.wait_until_finish()
  return result


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  run()
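

# A possible invocation of this example, sketched for reference. It assumes a
# TensorFlow SavedModel trained on MNIST and an input text file where each
# line is "label,pixel_0,...,pixel_783" (784 comma-separated pixel values, as
# expected by process_input above). The paths below are illustrative
# placeholders, not paths from the Beam repository:
#
#   python tensorflow_mnist_classification.py \
#     --input /path/to/mnist_test.csv \
#     --output /path/to/predictions.txt \
#     --model_path /path/to/saved_model_dir
#
# Unrecognized flags (e.g. runner selection) are forwarded to PipelineOptions
# via parse_known_args. Each line of the output file has the form
# "<true_label>,<predicted_label>", as produced by PostProcessor.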