github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/onnx_sentiment_classification.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
"""A pipeline that uses RunInference to perform sentiment classification
using RoBERTa.

This pipeline reads sentences from a custom text file and uses an ONNX export
of RoBERTa from Hugging Face to predict the sentiment of each review. It then
writes each sentence together with its predicted label to an output file, so
users can compare the predictions against the true labels.

The model is a fine-tuned RoBERTa from
https://github.com/SeldonIO/seldon-models/blob/master/pytorch/moviesentiment_roberta/pytorch-roberta-onnx.ipynb # pylint: disable=line-too-long
"""
    28  
    29  import argparse
    30  import logging
    31  from typing import Iterable
    32  from typing import Iterator
    33  from typing import Tuple
    34  
    35  import numpy as np
    36  
    37  import apache_beam as beam
    38  import torch
    39  from apache_beam.ml.inference.base import KeyedModelHandler
    40  from apache_beam.ml.inference.base import PredictionResult
    41  from apache_beam.ml.inference.base import RunInference
    42  from apache_beam.ml.inference.onnx_inference import OnnxModelHandlerNumpy
    43  from apache_beam.options.pipeline_options import PipelineOptions
    44  from apache_beam.options.pipeline_options import SetupOptions
    45  from apache_beam.runners.runner import PipelineResult
    46  from transformers import RobertaTokenizer
    47  
    48  
    49  def tokenize_sentence(text: str,
    50                        tokenizer: RobertaTokenizer) -> Tuple[str, torch.Tensor]:
    51    tokenized_sentence = tokenizer.encode(text, add_special_tokens=True)
    52  
    53    # Workaround to manually remove batch dim until we have the feature to
    54    # add optional batching flag.
    55    # TODO(https://github.com/apache/beam/issues/21863): Remove once optional
    56    # batching flag added
    57    return text, torch.tensor(tokenized_sentence).numpy()
    58  
    59  
    60  def filter_empty_lines(text: str) -> Iterator[str]:
    61    if len(text.strip()) > 0:
    62      yield text
    63  
    64  
    65  class PostProcessor(beam.DoFn):
    66    def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
    67      filename, prediction_result = element
    68      prediction = np.argmax(prediction_result.inference, axis=0)
    69      yield filename + ';' + str(prediction)
    70  
    71  
    72  def parse_known_args(argv):
    73    """Parses args for the workflow."""
    74    parser = argparse.ArgumentParser()
    75    parser.add_argument(
    76        '--input',
    77        dest='input',
    78        help='Path to the text file containing sentences.')
    79    parser.add_argument(
    80        '--output',
    81        dest='output',
    82        required=True,
    83        help='Path of file in which to save the output predictions.')
    84    parser.add_argument(
    85        '--model_uri',
    86        dest='model_uri',
    87        required=True,
    88        help="Path to the model's uri.")
    89    return parser.parse_known_args(argv)
    90  
    91  
    92  def run(
    93      argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
    94    """
    95    Args:
    96      argv: Command line arguments defined for this example.
    97      save_main_session: Used for internal testing.
    98      test_pipeline: Used for internal testing.
    99    """
   100    known_args, pipeline_args = parse_known_args(argv)
   101    pipeline_options = PipelineOptions(pipeline_args)
   102    pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
   103  
   104    # TODO: Remove once nested tensors https://github.com/pytorch/nestedtensor
   105    # is officially released.
   106    class OnnxNoBatchModelHandler(OnnxModelHandlerNumpy):
   107      """Wrapper to OnnxModelHandlerNumpy to limit batch size to 1.
   108  
   109      The tokenized strings generated from RobertaTokenizer may have different
   110      lengths, which doesn't work with torch.stack() in current RunInference
   111      implementation since stack() requires tensors to be the same size.
   112  
   113      Restricting max_batch_size to 1 means there is only 1 example per `batch`
   114      in the run_inference() call.
   115      """
   116      def batch_elements_kwargs(self):
   117        return {'max_batch_size': 1}
   118  
   119    model_handler = OnnxNoBatchModelHandler(model_uri=known_args.model_uri)
   120  
   121    pipeline = test_pipeline
   122    if not test_pipeline:
   123      pipeline = beam.Pipeline(options=pipeline_options)
   124  
   125    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
   126  
   127    text = (pipeline | 'ReadSentences' >> beam.io.ReadFromText(known_args.input))
   128    text_and_tokenized_text_tuple = (
   129        text
   130        | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines)
   131        |
   132        'TokenizeSentence' >> beam.Map(lambda x: tokenize_sentence(x, tokenizer)))
   133    output = (
   134        text_and_tokenized_text_tuple
   135        | 'PyTorchRunInference' >> RunInference(KeyedModelHandler(model_handler))
   136        | 'ProcessOutput' >> beam.ParDo(PostProcessor()))
   137    _ = output | "WriteOutput" >> beam.io.WriteToText(
   138        known_args.output, shard_name_template='', append_trailing_newlines=True)
   139  
   140    result = pipeline.run()
   141    result.wait_until_finish()
   142    return result
   143  
   144  
   145  if __name__ == '__main__':
   146    logging.getLogger().setLevel(logging.INFO)
   147    run()