github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/runinference_metrics/main.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
"""This file contains the pipeline for loading an ML model and exploring
the different RunInference metrics."""
    20  import argparse
    21  import logging
    22  import sys
    23  
    24  import apache_beam as beam
    25  import config as cfg
    26  from apache_beam.ml.inference import RunInference
    27  from apache_beam.ml.inference.base import KeyedModelHandler
    28  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
    29  from pipeline.options import get_pipeline_options
    30  from pipeline.transformations import CustomPytorchModelHandlerKeyedTensor
    31  from pipeline.transformations import HuggingFaceStripBatchingWrapper
    32  from pipeline.transformations import PostProcessor
    33  from pipeline.transformations import Tokenize
    34  from transformers import DistilBertConfig
    35  
    36  
    37  def parse_arguments(argv):
    38    """
    39      Parses the arguments passed to the command line and
    40      returns them as an object
    41      Args:
    42        argv: The arguments passed to the command line.
    43      Returns:
    44        The arguments that are being passed in.
    45      """
    46    parser = argparse.ArgumentParser(description="benchmark-runinference")
    47  
    48    parser.add_argument(
    49        "-m",
    50        "--mode",
    51        help="Mode to run pipeline in.",
    52        choices=["local", "cloud"],
    53        default="local",
    54    )
    55    parser.add_argument(
    56        "-p",
    57        "--project",
    58        help="GCP project to run pipeline on.",
    59        default=cfg.PROJECT_ID,
    60    )
    61    parser.add_argument(
    62        "-d",
    63        "--device",
    64        help="Device to run the dataflow job on",
    65        choices=["CPU", "GPU"],
    66        default="CPU",
    67    )
    68  
    69    args, _ = parser.parse_known_args(args=argv)
    70    return args
    71  
    72  
    73  def run():
    74    """
    75      Runs the pipeline that loads a transformer based text classification model
    76      and does inference on a list of sentences.
    77      At the end of pipeline, different metrics like latency,
    78      throughput and others are printed.
    79      """
    80    args = parse_arguments(sys.argv)
    81  
    82    inputs = [
    83        "This is the worst food I have ever eaten",
    84        "In my soul and in my heart, I’m convinced I’m wrong!",
    85        "Be with me always—take any form—drive me mad!"\
    86        "only do not leave me in this abyss, where I cannot find you!",
    87        "Do I want to live? Would you like to live with your soul in the grave?",
    88        "Honest people don’t hide their deeds.",
    89        "Nelly, I am Heathcliff!  He’s always,"\
    90        "always in my mind: not as a pleasure,"\
    91        "any more than I am always a pleasure to myself, but as my own being.",
    92    ] * 1000
    93  
    94    pipeline_options = get_pipeline_options(
    95        job_name=cfg.JOB_NAME,
    96        num_workers=cfg.NUM_WORKERS,
    97        project=args.project,
    98        mode=args.mode,
    99        device=args.device,
   100    )
   101    model_handler_class = (
   102        PytorchModelHandlerKeyedTensor
   103        if args.device == "GPU" else CustomPytorchModelHandlerKeyedTensor)
   104    device = "cuda:0" if args.device == "GPU" else args.device
   105    model_handler = model_handler_class(
   106        state_dict_path=cfg.MODEL_STATE_DICT_PATH,
   107        model_class=HuggingFaceStripBatchingWrapper,
   108        model_params={
   109            "config": DistilBertConfig.from_pretrained(cfg.MODEL_CONFIG_PATH)
   110        },
   111        device=device,
   112    )
   113  
   114    with beam.Pipeline(options=pipeline_options) as pipeline:
   115      _ = (
   116          pipeline
   117          | "Create inputs" >> beam.Create(inputs)
   118          | "Tokenize" >> beam.ParDo(Tokenize(cfg.TOKENIZER_NAME))
   119          | "Inference" >>
   120          RunInference(model_handler=KeyedModelHandler(model_handler))
   121          | "Decode Predictions" >> beam.ParDo(PostProcessor()))
   122    metrics = pipeline.result.metrics().query(beam.metrics.MetricsFilter())
   123    logging.info(metrics)
   124  
   125  
   126  if __name__ == "__main__":
   127    run()