github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/pytorch_language_modeling.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """"A pipeline that uses RunInference to perform Language Modeling with Bert.
    19  
    20  This pipeline takes sentences from a custom text file, converts the last word
    21  of the sentence into a [MASK] token, and then uses the BertForMaskedLM from
    22  Hugging Face to predict the best word for the masked token given all the words
    23  already in the sentence. The pipeline then writes the prediction to an output
    24  file in which users can then compare against the original sentence.
    25  """
    26  
    27  import argparse
    28  import logging
    29  from typing import Dict
    30  from typing import Iterable
    31  from typing import Iterator
    32  from typing import Tuple
    33  
    34  import apache_beam as beam
    35  import torch
    36  from apache_beam.ml.inference.base import KeyedModelHandler
    37  from apache_beam.ml.inference.base import PredictionResult
    38  from apache_beam.ml.inference.base import RunInference
    39  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
    40  from apache_beam.options.pipeline_options import PipelineOptions
    41  from apache_beam.options.pipeline_options import SetupOptions
    42  from apache_beam.runners.runner import PipelineResult
    43  from transformers import BertConfig
    44  from transformers import BertForMaskedLM
    45  from transformers import BertTokenizer
    46  
    47  
    48  def add_mask_to_last_word(text: str) -> Tuple[str, str]:
    49    text_list = text.split()
    50    return text, ' '.join(text_list[:-2] + ['[MASK]', text_list[-1]])
    51  
    52  
    53  def tokenize_sentence(
    54      text_and_mask: Tuple[str, str],
    55      bert_tokenizer: BertTokenizer) -> Tuple[str, Dict[str, torch.Tensor]]:
    56    text, masked_text = text_and_mask
    57    tokenized_sentence = bert_tokenizer.encode_plus(
    58        masked_text, return_tensors="pt")
    59  
    60    # Workaround to manually remove batch dim until we have the feature to
    61    # add optional batching flag.
    62    # TODO(https://github.com/apache/beam/issues/21863): Remove once optional
    63    # batching flag added
    64    return text, {
    65        k: torch.squeeze(v)
    66        for k, v in dict(tokenized_sentence).items()
    67    }
    68  
    69  
    70  def filter_empty_lines(text: str) -> Iterator[str]:
    71    if len(text.strip()) > 0:
    72      yield text
    73  
    74  
    75  class PostProcessor(beam.DoFn):
    76    """Processes the PredictionResult to get the predicted word.
    77  
    78    The logits are the output of the BERT Model. After applying a softmax
    79    activation function to the logits, we get probabilistic distributions for each
    80    of the words in BERT’s vocabulary. We can get the word with the highest
    81    probability of being a candidate replacement word by taking the argmax.
    82    """
    83    def __init__(self, bert_tokenizer: BertTokenizer):
    84      super().__init__()
    85      self.bert_tokenizer = bert_tokenizer
    86  
    87    def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
    88      text, prediction_result = element
    89      inputs = prediction_result.example
    90      logits = prediction_result.inference['logits']
    91      mask_token_index = (
    92          inputs['input_ids'] == self.bert_tokenizer.mask_token_id).nonzero(
    93              as_tuple=True)[0]
    94      predicted_token_id = logits[mask_token_index].argmax(axis=-1)
    95      decoded_word = self.bert_tokenizer.decode(predicted_token_id)
    96      yield text + ';' + decoded_word
    97  
    98  
    99  def parse_known_args(argv):
   100    """Parses args for the workflow."""
   101    parser = argparse.ArgumentParser()
   102    parser.add_argument(
   103        '--input',
   104        dest='input',
   105        help='Path to the text file containing sentences.')
   106    parser.add_argument(
   107        '--output',
   108        dest='output',
   109        required=True,
   110        help='Path of file in which to save the output predictions.')
   111    parser.add_argument(
   112        '--bert_tokenizer',
   113        dest='bert_tokenizer',
   114        default='bert-base-uncased',
   115        help='bert uncased model. This can be base model or large model')
   116    parser.add_argument(
   117        '--model_state_dict_path',
   118        dest='model_state_dict_path',
   119        required=True,
   120        help="Path to the model's state_dict.")
   121    return parser.parse_known_args(argv)
   122  
   123  
   124  def run(
   125      argv=None,
   126      model_class=None,
   127      model_params=None,
   128      save_main_session=True,
   129      test_pipeline=None) -> PipelineResult:
   130    """
   131    Args:
   132      argv: Command line arguments defined for this example.
   133      model_class: Reference to the class definition of the model.
   134                  If None, BertForMaskedLM will be used as default .
   135      model_params: Parameters passed to the constructor of the model_class.
   136                    These will be used to instantiate the model object in the
   137                    RunInference API.
   138      save_main_session: Used for internal testing.
   139      test_pipeline: Used for internal testing.
   140    """
   141    known_args, pipeline_args = parse_known_args(argv)
   142    pipeline_options = PipelineOptions(pipeline_args)
   143    pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
   144  
   145    if not model_class:
   146      model_config = BertConfig.from_pretrained(
   147          known_args.bert_tokenizer, is_decoder=False, return_dict=True)
   148      model_class = BertForMaskedLM
   149      model_params = {'config': model_config}
   150  
   151    # TODO: Remove once nested tensors https://github.com/pytorch/nestedtensor
   152    # is officially released.
   153    class PytorchNoBatchModelHandler(PytorchModelHandlerKeyedTensor):
   154      """Wrapper to PytorchModelHandler to limit batch size to 1.
   155  
   156      The tokenized strings generated from BertTokenizer may have different
   157      lengths, which doesn't work with torch.stack() in current RunInference
   158      implementation since stack() requires tensors to be the same size.
   159  
   160      Restricting max_batch_size to 1 means there is only 1 example per `batch`
   161      in the run_inference() call.
   162      """
   163      def batch_elements_kwargs(self):
   164        return {'max_batch_size': 1}
   165  
   166    model_handler = PytorchNoBatchModelHandler(
   167        state_dict_path=known_args.model_state_dict_path,
   168        model_class=model_class,
   169        model_params=model_params)
   170  
   171    pipeline = test_pipeline
   172    if not test_pipeline:
   173      pipeline = beam.Pipeline(options=pipeline_options)
   174  
   175    bert_tokenizer = BertTokenizer.from_pretrained(known_args.bert_tokenizer)
   176  
   177    if not known_args.input:
   178      text = (pipeline | 'CreateSentences' >> beam.Create([
   179        'The capital of France is Paris .',
   180        'It is raining cats and dogs .',
   181        'He looked up and saw the sun and stars .',
   182        'Today is Monday and tomorrow is Tuesday .',
   183        'There are 5 coconuts on this palm tree .',
   184        'The richest person in the world is not here .',
   185        'Malls are amazing places to shop because you can find everything you need under one roof .', # pylint: disable=line-too-long
   186        'This audiobook is sure to liquefy your brain .',
   187        'The secret ingredient to his wonderful life was gratitude .',
   188        'The biggest animal in the world is the whale .',
   189      ]))
   190    else:
   191      text = (
   192          pipeline | 'ReadSentences' >> beam.io.ReadFromText(known_args.input))
   193    text_and_tokenized_text_tuple = (
   194        text
   195        | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines)
   196        | 'AddMask' >> beam.Map(add_mask_to_last_word)
   197        | 'TokenizeSentence' >>
   198        beam.Map(lambda x: tokenize_sentence(x, bert_tokenizer)))
   199    output = (
   200        text_and_tokenized_text_tuple
   201        | 'PyTorchRunInference' >> RunInference(KeyedModelHandler(model_handler))
   202        | 'ProcessOutput' >> beam.ParDo(
   203            PostProcessor(bert_tokenizer=bert_tokenizer)))
   204    output | "WriteOutput" >> beam.io.WriteToText( # pylint: disable=expression-not-assigned
   205      known_args.output,
   206      shard_name_template='',
   207      append_trailing_newlines=True)
   208  
   209    result = pipeline.run()
   210    result.wait_until_finish()
   211    return result
   212  
   213  
   214  if __name__ == '__main__':
   215    logging.getLogger().setLevel(logging.INFO)
   216    run()