github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/onnx_sentiment_classification.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A pipeline that uses RunInference to perform sentiment classification
using RoBERTa.

This pipeline reads sentences from a custom text file and uses a RoBERTa
model from Hugging Face to predict the sentiment of each review. It then
writes each prediction to an output file, which users can compare against
the true labels.

The model is a fine-tuned RoBERTa from
https://github.com/SeldonIO/seldon-models/blob/master/pytorch/moviesentiment_roberta/pytorch-roberta-onnx.ipynb # pylint: disable=line-too-long
"""

import argparse
import logging
from typing import Iterable
from typing import Iterator
from typing import Tuple

import numpy as np

import apache_beam as beam
import torch
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.onnx_inference import OnnxModelHandlerNumpy
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult
from transformers import RobertaTokenizer


def tokenize_sentence(text: str,
                      tokenizer: RobertaTokenizer) -> Tuple[str, np.ndarray]:
  tokenized_sentence = tokenizer.encode(text, add_special_tokens=True)

  # Workaround to manually remove the batch dim until we have the feature to
  # add an optional batching flag.
  # TODO(https://github.com/apache/beam/issues/21863): Remove once an
  # optional batching flag is added.
  return text, torch.tensor(tokenized_sentence).numpy()


def filter_empty_lines(text: str) -> Iterator[str]:
  if len(text.strip()) > 0:
    yield text


class PostProcessor(beam.DoFn):
  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
    # The key is the original sentence; the value holds the model output.
    text, prediction_result = element
    prediction = np.argmax(prediction_result.inference, axis=0)
    yield text + ';' + str(prediction)


def parse_known_args(argv):
  """Parses args for the workflow."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      dest='input',
      help='Path to the text file containing sentences.')
  parser.add_argument(
      '--output',
      dest='output',
      required=True,
      help='Path of file in which to save the output predictions.')
  parser.add_argument(
      '--model_uri',
      dest='model_uri',
      required=True,
      help='URI of the ONNX model to load.')
  return parser.parse_known_args(argv)


def run(
    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
  """
  Args:
    argv: Command line arguments defined for this example.
    save_main_session: Used for internal testing.
    test_pipeline: Used for internal testing.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  # TODO: Remove once nested tensors (https://github.com/pytorch/nestedtensor)
  # are officially released.
  class OnnxNoBatchModelHandler(OnnxModelHandlerNumpy):
    """Wrapper around OnnxModelHandlerNumpy that limits batch size to 1.

    The tokenized strings generated from RobertaTokenizer may have different
    lengths, which doesn't work with torch.stack() in the current
    RunInference implementation, since stack() requires tensors to be the
    same size.

    Restricting max_batch_size to 1 means there is only one example per
    `batch` in the run_inference() call.
    """
    def batch_elements_kwargs(self):
      return {'max_batch_size': 1}

  model_handler = OnnxNoBatchModelHandler(model_uri=known_args.model_uri)

  pipeline = test_pipeline
  if not test_pipeline:
    pipeline = beam.Pipeline(options=pipeline_options)

  tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

  text = (pipeline | 'ReadSentences' >> beam.io.ReadFromText(known_args.input))
  text_and_tokenized_text_tuple = (
      text
      | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines)
      |
      'TokenizeSentence' >> beam.Map(lambda x: tokenize_sentence(x, tokenizer)))
  output = (
      text_and_tokenized_text_tuple
      | 'OnnxRunInference' >> RunInference(KeyedModelHandler(model_handler))
      | 'ProcessOutput' >> beam.ParDo(PostProcessor()))
  _ = output | 'WriteOutput' >> beam.io.WriteToText(
      known_args.output, shard_name_template='', append_trailing_newlines=True)

  result = pipeline.run()
  result.wait_until_finish()
  return result


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  run()
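
# Example invocation, as a minimal sketch: the file paths and the
# DirectRunner choice below are assumptions, while --input, --output, and
# --model_uri are the flags defined in parse_known_args() above. Each output
# line has the form '<sentence>;<predicted class index>'.
#
#   python onnx_sentiment_classification.py \
#     --runner DirectRunner \
#     --input /path/to/sentences.txt \
#     --output /path/to/predictions.txt \
#     --model_uri /path/to/roberta-sentiment.onnx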