github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/inference/pytorch_language_modeling.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A pipeline that uses RunInference to perform Language Modeling with Bert.

This pipeline takes sentences from a custom text file, converts the last word
of each sentence into a [MASK] token, and then uses the BertForMaskedLM model
from Hugging Face to predict the best word for the masked token given all the
words already in the sentence. The pipeline then writes the prediction to an
output file, which users can compare against the original sentence.
"""

import argparse
import logging
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import Tuple

import apache_beam as beam
import torch
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult
from transformers import BertConfig
from transformers import BertForMaskedLM
from transformers import BertTokenizer


def add_mask_to_last_word(text: str) -> Tuple[str, str]:
  text_list = text.split()
  return text, ' '.join(text_list[:-2] + ['[MASK]', text_list[-1]])


def tokenize_sentence(
    text_and_mask: Tuple[str, str],
    bert_tokenizer: BertTokenizer) -> Tuple[str, Dict[str, torch.Tensor]]:
  text, masked_text = text_and_mask
  tokenized_sentence = bert_tokenizer.encode_plus(
      masked_text, return_tensors="pt")

  # Workaround to manually remove batch dim until we have the feature to
  # add optional batching flag.
  # TODO(https://github.com/apache/beam/issues/21863): Remove once optional
  # batching flag added.
  return text, {
      k: torch.squeeze(v)
      for k, v in dict(tokenized_sentence).items()
  }


def filter_empty_lines(text: str) -> Iterator[str]:
  if len(text.strip()) > 0:
    yield text
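
# Illustrative example (not part of the original file): add_mask_to_last_word
# masks the second-to-last token, which is the last actual word because every
# input sentence ends with ' .':
#
#   add_mask_to_last_word('The capital of France is Paris .')
#   # -> ('The capital of France is Paris .',
#   #     'The capital of France is [MASK] .')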
82 """ 83 def __init__(self, bert_tokenizer: BertTokenizer): 84 super().__init__() 85 self.bert_tokenizer = bert_tokenizer 86 87 def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: 88 text, prediction_result = element 89 inputs = prediction_result.example 90 logits = prediction_result.inference['logits'] 91 mask_token_index = ( 92 inputs['input_ids'] == self.bert_tokenizer.mask_token_id).nonzero( 93 as_tuple=True)[0] 94 predicted_token_id = logits[mask_token_index].argmax(axis=-1) 95 decoded_word = self.bert_tokenizer.decode(predicted_token_id) 96 yield text + ';' + decoded_word 97 98 99 def parse_known_args(argv): 100 """Parses args for the workflow.""" 101 parser = argparse.ArgumentParser() 102 parser.add_argument( 103 '--input', 104 dest='input', 105 help='Path to the text file containing sentences.') 106 parser.add_argument( 107 '--output', 108 dest='output', 109 required=True, 110 help='Path of file in which to save the output predictions.') 111 parser.add_argument( 112 '--bert_tokenizer', 113 dest='bert_tokenizer', 114 default='bert-base-uncased', 115 help='bert uncased model. This can be base model or large model') 116 parser.add_argument( 117 '--model_state_dict_path', 118 dest='model_state_dict_path', 119 required=True, 120 help="Path to the model's state_dict.") 121 return parser.parse_known_args(argv) 122 123 124 def run( 125 argv=None, 126 model_class=None, 127 model_params=None, 128 save_main_session=True, 129 test_pipeline=None) -> PipelineResult: 130 """ 131 Args: 132 argv: Command line arguments defined for this example. 133 model_class: Reference to the class definition of the model. 134 If None, BertForMaskedLM will be used as default . 135 model_params: Parameters passed to the constructor of the model_class. 136 These will be used to instantiate the model object in the 137 RunInference API. 138 save_main_session: Used for internal testing. 139 test_pipeline: Used for internal testing. 140 """ 141 known_args, pipeline_args = parse_known_args(argv) 142 pipeline_options = PipelineOptions(pipeline_args) 143 pipeline_options.view_as(SetupOptions).save_main_session = save_main_session 144 145 if not model_class: 146 model_config = BertConfig.from_pretrained( 147 known_args.bert_tokenizer, is_decoder=False, return_dict=True) 148 model_class = BertForMaskedLM 149 model_params = {'config': model_config} 150 151 # TODO: Remove once nested tensors https://github.com/pytorch/nestedtensor 152 # is officially released. 153 class PytorchNoBatchModelHandler(PytorchModelHandlerKeyedTensor): 154 """Wrapper to PytorchModelHandler to limit batch size to 1. 155 156 The tokenized strings generated from BertTokenizer may have different 157 lengths, which doesn't work with torch.stack() in current RunInference 158 implementation since stack() requires tensors to be the same size. 159 160 Restricting max_batch_size to 1 means there is only 1 example per `batch` 161 in the run_inference() call. 
162 """ 163 def batch_elements_kwargs(self): 164 return {'max_batch_size': 1} 165 166 model_handler = PytorchNoBatchModelHandler( 167 state_dict_path=known_args.model_state_dict_path, 168 model_class=model_class, 169 model_params=model_params) 170 171 pipeline = test_pipeline 172 if not test_pipeline: 173 pipeline = beam.Pipeline(options=pipeline_options) 174 175 bert_tokenizer = BertTokenizer.from_pretrained(known_args.bert_tokenizer) 176 177 if not known_args.input: 178 text = (pipeline | 'CreateSentences' >> beam.Create([ 179 'The capital of France is Paris .', 180 'It is raining cats and dogs .', 181 'He looked up and saw the sun and stars .', 182 'Today is Monday and tomorrow is Tuesday .', 183 'There are 5 coconuts on this palm tree .', 184 'The richest person in the world is not here .', 185 'Malls are amazing places to shop because you can find everything you need under one roof .', # pylint: disable=line-too-long 186 'This audiobook is sure to liquefy your brain .', 187 'The secret ingredient to his wonderful life was gratitude .', 188 'The biggest animal in the world is the whale .', 189 ])) 190 else: 191 text = ( 192 pipeline | 'ReadSentences' >> beam.io.ReadFromText(known_args.input)) 193 text_and_tokenized_text_tuple = ( 194 text 195 | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines) 196 | 'AddMask' >> beam.Map(add_mask_to_last_word) 197 | 'TokenizeSentence' >> 198 beam.Map(lambda x: tokenize_sentence(x, bert_tokenizer))) 199 output = ( 200 text_and_tokenized_text_tuple 201 | 'PyTorchRunInference' >> RunInference(KeyedModelHandler(model_handler)) 202 | 'ProcessOutput' >> beam.ParDo( 203 PostProcessor(bert_tokenizer=bert_tokenizer))) 204 output | "WriteOutput" >> beam.io.WriteToText( # pylint: disable=expression-not-assigned 205 known_args.output, 206 shard_name_template='', 207 append_trailing_newlines=True) 208 209 result = pipeline.run() 210 result.wait_until_finish() 211 return result 212 213 214 if __name__ == '__main__': 215 logging.getLogger().setLevel(logging.INFO) 216 run()