github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/large_language_modeling/main.py

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License

"""A pipeline that uses RunInference to perform translation
with a T5 language model.

This pipeline takes a list of English sentences and uses the
T5ForConditionalGeneration model from Hugging Face to translate
each sentence into German.
"""
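# Example invocation (illustrative; the state_dict path is a placeholder for
# wherever the T5 weights were saved, and t5-small is a lighter-weight
# alternative to the default t5-11b):
#   python main.py \
#       --model_state_dict_path <path/to/saved_state_dict.pth> \
#       --model_name t5-small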
import argparse
import sys

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
from apache_beam.ml.inference.pytorch_inference import make_tensor_model_fn
from apache_beam.options.pipeline_options import PipelineOptions
from transformers import AutoConfig
from transformers import AutoTokenizer
from transformers import T5ForConditionalGeneration


class Preprocess(beam.DoFn):
  def __init__(self, tokenizer: AutoTokenizer):
    self._tokenizer = tokenizer

  def process(self, element):
    """Processes the raw text input into a format suitable for
    T5ForConditionalGeneration model inference.

    Args:
      element: A string of text.

    Returns:
      A tokenized example that can be read by the
      T5ForConditionalGeneration model.
    """
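    # Pad each example to a fixed length of 512 tokens so that RunInference
    # can batch the inputs; batched tensors must all share the same shape.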
    input_ids = self._tokenizer(
        element, return_tensors="pt", padding="max_length",
        max_length=512).input_ids
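    # input_ids has shape (1, 512); returning it from process() makes Beam
    # iterate over it, emitting each row as a separate 1-D token-id tensor.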
    return input_ids


class Postprocess(beam.DoFn):
  def __init__(self, tokenizer: AutoTokenizer):
    self._tokenizer = tokenizer

  def process(self, element):
    """Processes the PredictionResult to print the translated text.

    Args:
      element: The RunInference output to be processed.
    """
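    # element is a PredictionResult named tuple: `example` holds the input
    # token ids and `inference` holds the generated (translated) token ids.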
    decoded_inputs = self._tokenizer.decode(
        element.example, skip_special_tokens=True)
    decoded_outputs = self._tokenizer.decode(
        element.inference, skip_special_tokens=True)
    print(f"{decoded_inputs} \t Output: {decoded_outputs}")


def parse_args(argv):
  """Parses args for the workflow."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--model_state_dict_path",
      dest="model_state_dict_path",
      required=True,
      help="Path to the model's state_dict.",
  )
  parser.add_argument(
      "--model_name",
      dest="model_name",
      required=False,
      help="Name of the pretrained T5 model (default: t5-11b).",
      default="t5-11b",
  )

  return parser.parse_known_args(args=argv)


def run():
  """Runs the translation pipeline, which translates English sentences
  into German using the RunInference API."""

  known_args, pipeline_args = parse_args(sys.argv)
  pipeline_options = PipelineOptions(pipeline_args)

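  # make_tensor_model_fn('generate') wraps the model's generate() method as
  # the inference function, instead of the default forward() call; T5 needs
  # generate() to produce complete translated sequences.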
  gen_fn = make_tensor_model_fn('generate')
  model_handler = PytorchModelHandlerTensor(
      state_dict_path=known_args.model_state_dict_path,
      model_class=T5ForConditionalGeneration,
      model_params={
          "config": AutoConfig.from_pretrained(known_args.model_name)
      },
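      # "cpu" pins inference to CPU; the handler also accepts "GPU" to run on
      # CUDA when a GPU is available (it falls back to CPU otherwise).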
      device="cpu",
      inference_fn=gen_fn)

  eng_sentences = [
      "The house is wonderful.",
      "I like to work in NYC.",
      "My name is Shubham.",
      "I want to work for Google.",
      "I am from India."
  ]
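  # T5 is a text-to-text model: prefixing each sentence with a task
  # description tells the model which task to perform (here,
  # English-to-German translation).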
  task_prefix = "translate English to German: "
  task_sentences = [task_prefix + sentence for sentence in eng_sentences]
  tokenizer = AutoTokenizer.from_pretrained(known_args.model_name)

  # [START Pipeline]
  with beam.Pipeline(options=pipeline_options) as pipeline:
    _ = (
        pipeline
        | "CreateInputs" >> beam.Create(task_sentences)
        | "Preprocess" >> beam.ParDo(Preprocess(tokenizer=tokenizer))
        | "RunInference" >> RunInference(model_handler=model_handler)
        | "PostProcess" >> beam.ParDo(Postprocess(tokenizer=tokenizer)))
  # [END Pipeline]


if __name__ == "__main__":
  run()