github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/tensorrt_text_classification.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """A pipeline to demonstrate usage of TensorRT with RunInference 19 for a text classification model. This pipeline reads data from a text 20 file, preprocesses the data, and then uses RunInference to generate 21 predictions from the text classification TensorRT engine. Next, 22 it postprocesses the RunInference outputs to print the input and 23 the predicted class label. 24 It also prints metrics provided by RunInference. 25 """ 26 27 import argparse 28 import logging 29 30 import numpy as np 31 32 import apache_beam as beam 33 from apache_beam.ml.inference.base import RunInference 34 from apache_beam.ml.inference.tensorrt_inference import TensorRTEngineHandlerNumPy 35 from apache_beam.options.pipeline_options import PipelineOptions 36 from apache_beam.options.pipeline_options import SetupOptions 37 from transformers import AutoTokenizer 38 39 40 class Preprocess(beam.DoFn): 41 """Processes the input sentences to tokenize them. 42 43 The input sentences are tokenized because the 44 model is expecting tokens. 45 """ 46 def __init__(self, tokenizer: AutoTokenizer): 47 self._tokenizer = tokenizer 48 49 def process(self, element): 50 inputs = self._tokenizer( 51 element, return_tensors="np", padding="max_length", max_length=128) 52 return inputs.input_ids 53 54 55 class Postprocess(beam.DoFn): 56 """Processes the PredictionResult to get the predicted class. 57 58 The logits are the output of the TensorRT engine. 59 We can get the class label by getting the index of 60 maximum logit using argmax. 61 """ 62 def __init__(self, tokenizer: AutoTokenizer): 63 self._tokenizer = tokenizer 64 65 def process(self, element): 66 decoded_input = self._tokenizer.decode( 67 element.example, skip_special_tokens=True) 68 logits = element.inference[0] 69 argmax = np.argmax(logits) 70 output = "Positive" if argmax == 1 else "Negative" 71 yield decoded_input, output 72 73 74 def parse_known_args(argv): 75 """Parses args for the workflow.""" 76 parser = argparse.ArgumentParser() 77 parser.add_argument( 78 '--input', 79 dest='input', 80 required=True, 81 help='Path to the text file containing sentences.') 82 parser.add_argument( 83 '--trt_model_path', 84 dest='trt_model_path', 85 required=True, 86 help='Path to the pre-built textattack/bert-base-uncased-SST-2' 87 'TensorRT engine.') 88 parser.add_argument( 89 '--model_id', 90 dest='model_id', 91 default="textattack/bert-base-uncased-SST-2", 92 help="name of model.") 93 return parser.parse_known_args(argv) 94 95 96 def run( 97 argv=None, 98 save_main_session=True, 99 ): 100 known_args, pipeline_args = parse_known_args(argv) 101 pipeline_options = PipelineOptions(pipeline_args) 102 pipeline_options.view_as(SetupOptions).save_main_session = save_main_session 103 104 model_handler = TensorRTEngineHandlerNumPy( 105 min_batch_size=1, 106 max_batch_size=1, 107 engine_path=known_args.trt_model_path, 108 ) 109 110 tokenizer = AutoTokenizer.from_pretrained(known_args.model_id) 111 112 with beam.Pipeline(options=pipeline_options) as pipeline: 113 _ = ( 114 pipeline 115 | "ReadSentences" >> beam.io.ReadFromText(known_args.input) 116 | "Preprocess" >> beam.ParDo(Preprocess(tokenizer=tokenizer)) 117 | "RunInference" >> RunInference(model_handler=model_handler) 118 | "PostProcess" >> beam.ParDo(Postprocess(tokenizer=tokenizer)) 119 | "LogResult" >> beam.Map(logging.info)) 120 metrics = pipeline.result.metrics().query(beam.metrics.MetricsFilter()) 121 logging.info(metrics) 122 123 124 if __name__ == '__main__': 125 logging.getLogger().setLevel(logging.INFO) 126 run()