github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/runinference_metrics/main.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""This file contains the pipeline for loading an ML model and exploring
the different RunInference metrics."""
import argparse
import logging
import sys

import apache_beam as beam
import config as cfg
from apache_beam.ml.inference import RunInference
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
from pipeline.options import get_pipeline_options
from pipeline.transformations import CustomPytorchModelHandlerKeyedTensor
from pipeline.transformations import HuggingFaceStripBatchingWrapper
from pipeline.transformations import PostProcessor
from pipeline.transformations import Tokenize
from transformers import DistilBertConfig


def parse_arguments(argv):
  """
  Parses the command-line arguments and returns them as a namespace object.

  Args:
    argv: The arguments passed on the command line.
  Returns:
    The parsed arguments.
  """
  parser = argparse.ArgumentParser(description="benchmark-runinference")

  parser.add_argument(
      "-m",
      "--mode",
      help="Mode to run the pipeline in.",
      choices=["local", "cloud"],
      default="local",
  )
  parser.add_argument(
      "-p",
      "--project",
      help="GCP project to run the pipeline on.",
      default=cfg.PROJECT_ID,
  )
  parser.add_argument(
      "-d",
      "--device",
      help="Device to run the Dataflow job on.",
      choices=["CPU", "GPU"],
      default="CPU",
  )

  args, _ = parser.parse_known_args(args=argv)
  return args


def run():
  """
  Runs a pipeline that loads a transformer-based text classification model
  and performs inference on a list of sentences. At the end of the pipeline,
  metrics such as latency and throughput are logged.
  """
  args = parse_arguments(sys.argv)

  inputs = [
      "This is the worst food I have ever eaten",
      "In my soul and in my heart, I’m convinced I’m wrong!",
      "Be with me always—take any form—drive me mad!"
      "only do not leave me in this abyss, where I cannot find you!",
      "Do I want to live? Would you like to live with your soul in the grave?",
      "Honest people don’t hide their deeds.",
      "Nelly, I am Heathcliff! He’s always,"
      "always in my mind: not as a pleasure,"
      "any more than I am always a pleasure to myself, but as my own being.",
  ] * 1000

  pipeline_options = get_pipeline_options(
      job_name=cfg.JOB_NAME,
      num_workers=cfg.NUM_WORKERS,
      project=args.project,
      mode=args.mode,
      device=args.device,
  )
  # On GPU, the stock Beam handler is used and the model is placed on the
  # first CUDA device; on CPU, a custom handler from the local
  # pipeline.transformations module is used instead.
  model_handler_class = (
      PytorchModelHandlerKeyedTensor
      if args.device == "GPU" else CustomPytorchModelHandlerKeyedTensor)
  device = "cuda:0" if args.device == "GPU" else args.device
  model_handler = model_handler_class(
      state_dict_path=cfg.MODEL_STATE_DICT_PATH,
      model_class=HuggingFaceStripBatchingWrapper,
      model_params={
          "config": DistilBertConfig.from_pretrained(cfg.MODEL_CONFIG_PATH)
      },
      device=device,
  )

  with beam.Pipeline(options=pipeline_options) as pipeline:
    _ = (
        pipeline
        | "Create inputs" >> beam.Create(inputs)
        | "Tokenize" >> beam.ParDo(Tokenize(cfg.TOKENIZER_NAME))
        | "Inference" >>
        RunInference(model_handler=KeyedModelHandler(model_handler))
        | "Decode Predictions" >> beam.ParDo(PostProcessor()))

  # Query the metrics after the `with` block exits: leaving the context
  # manager runs the pipeline to completion and populates `pipeline.result`.
  metrics = pipeline.result.metrics().query(beam.metrics.MetricsFilter())
  logging.info(metrics)


if __name__ == "__main__":
  run()