github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/tensorrt_text_classification.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """A pipeline to demonstrate usage of TensorRT with RunInference
    19  for a text classification model. This pipeline reads data from a text
    20  file, preprocesses the data, and then uses RunInference to generate
    21  predictions from the text classification TensorRT engine. Next,
    22  it postprocesses the RunInference outputs to print the input and
    23  the predicted class label.
    24  It also prints metrics provided by RunInference.
    25  """
    26  
    27  import argparse
    28  import logging
    29  
    30  import numpy as np
    31  
    32  import apache_beam as beam
    33  from apache_beam.ml.inference.base import RunInference
    34  from apache_beam.ml.inference.tensorrt_inference import TensorRTEngineHandlerNumPy
    35  from apache_beam.options.pipeline_options import PipelineOptions
    36  from apache_beam.options.pipeline_options import SetupOptions
    37  from transformers import AutoTokenizer
    38  
    39  
    40  class Preprocess(beam.DoFn):
    41    """Processes the input sentences to tokenize them.
    42  
    43    The input sentences are tokenized because the
    44    model is expecting tokens.
    45    """
    46    def __init__(self, tokenizer: AutoTokenizer):
    47      self._tokenizer = tokenizer
    48  
    49    def process(self, element):
    50      inputs = self._tokenizer(
    51          element, return_tensors="np", padding="max_length", max_length=128)
    52      return inputs.input_ids
    53  
    54  
    55  class Postprocess(beam.DoFn):
    56    """Processes the PredictionResult to get the predicted class.
    57  
    58    The logits are the output of the TensorRT engine.
    59    We can get the class label by getting the index of
    60    maximum logit using argmax.
    61    """
    62    def __init__(self, tokenizer: AutoTokenizer):
    63      self._tokenizer = tokenizer
    64  
    65    def process(self, element):
    66      decoded_input = self._tokenizer.decode(
    67          element.example, skip_special_tokens=True)
    68      logits = element.inference[0]
    69      argmax = np.argmax(logits)
    70      output = "Positive" if argmax == 1 else "Negative"
    71      yield decoded_input, output
    72  
    73  
    74  def parse_known_args(argv):
    75    """Parses args for the workflow."""
    76    parser = argparse.ArgumentParser()
    77    parser.add_argument(
    78        '--input',
    79        dest='input',
    80        required=True,
    81        help='Path to the text file containing sentences.')
    82    parser.add_argument(
    83        '--trt_model_path',
    84        dest='trt_model_path',
    85        required=True,
    86        help='Path to the pre-built textattack/bert-base-uncased-SST-2'
    87        'TensorRT engine.')
    88    parser.add_argument(
    89        '--model_id',
    90        dest='model_id',
    91        default="textattack/bert-base-uncased-SST-2",
    92        help="name of model.")
    93    return parser.parse_known_args(argv)
    94  
    95  
    96  def run(
    97      argv=None,
    98      save_main_session=True,
    99  ):
   100    known_args, pipeline_args = parse_known_args(argv)
   101    pipeline_options = PipelineOptions(pipeline_args)
   102    pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
   103  
   104    model_handler = TensorRTEngineHandlerNumPy(
   105        min_batch_size=1,
   106        max_batch_size=1,
   107        engine_path=known_args.trt_model_path,
   108    )
   109  
   110    tokenizer = AutoTokenizer.from_pretrained(known_args.model_id)
   111  
   112    with beam.Pipeline(options=pipeline_options) as pipeline:
   113      _ = (
   114          pipeline
   115          | "ReadSentences" >> beam.io.ReadFromText(known_args.input)
   116          | "Preprocess" >> beam.ParDo(Preprocess(tokenizer=tokenizer))
   117          | "RunInference" >> RunInference(model_handler=model_handler)
   118          | "PostProcess" >> beam.ParDo(Postprocess(tokenizer=tokenizer))
   119          | "LogResult" >> beam.Map(logging.info))
   120    metrics = pipeline.result.metrics().query(beam.metrics.MetricsFilter())
   121    logging.info(metrics)
   122  
   123  
if __name__ == '__main__':
  # Set the root logger to INFO so the messages emitted through
  # logging.info by the pipeline and by run() are not filtered out.
  logging.getLogger().setLevel(logging.INFO)
  run()