github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/pytorch_inference.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import logging
from collections import defaultdict
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence

import torch
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult

__all__ = [
    'PytorchModelHandlerTensor',
    'PytorchModelHandlerKeyedTensor',
]

TensorInferenceFn = Callable[[
    Sequence[torch.Tensor],
    torch.nn.Module,
    torch.device,
    Optional[Dict[str, Any]],
    Optional[str]
],
                             Iterable[PredictionResult]]

KeyedTensorInferenceFn = Callable[[
    Sequence[Dict[str, torch.Tensor]],
    torch.nn.Module,
    torch.device,
    Optional[Dict[str, Any]],
    Optional[str]
],
                                  Iterable[PredictionResult]]


def _validate_constructor_args(
    state_dict_path, model_class, torch_script_model_path):
  message = (
      "A {param1} has been supplied to the model "
      "handler, but the required {param2} is missing. "
      "Please provide the {param2} in order to "
      "successfully load the {param1}.")
  # state_dict_path and model_class are coupled with each other;
  # raise RuntimeError if the user forgets to pass either one of them.
  if state_dict_path and not model_class:
    raise RuntimeError(
        message.format(param1="state_dict_path", param2="model_class"))

  if not state_dict_path and model_class:
    raise RuntimeError(
        message.format(param1="model_class", param2="state_dict_path"))

  if torch_script_model_path and state_dict_path:
    raise RuntimeError(
        "Please specify either torch_script_model_path or "
        "(state_dict_path, model_class) to successfully load the model.")


def _load_model(
    model_class: Optional[Callable[..., torch.nn.Module]],
    state_dict_path: Optional[str],
    device: torch.device,
    model_params: Optional[Dict[str, Any]],
    torch_script_model_path: Optional[str],
    load_model_args: Optional[Dict[str, Any]]):
  if device == torch.device('cuda') and not torch.cuda.is_available():
    logging.warning(
        "Model handler specified a 'GPU' device, but GPUs are not available. "
        "Switching to CPU.")
    device = torch.device('cpu')

  try:
    logging.info(
        "Loading state_dict_path %s onto a %s device", state_dict_path, device)
    if not torch_script_model_path:
      file = FileSystems.open(state_dict_path, 'rb')
      model = model_class(**model_params)  # type: ignore[arg-type,misc]
      state_dict = torch.load(file, map_location=device, **load_model_args)
      model.load_state_dict(state_dict)
    else:
      file = FileSystems.open(torch_script_model_path, 'rb')
      model = torch.jit.load(file, map_location=device, **load_model_args)
  except RuntimeError as e:
    if device == torch.device('cuda'):
      message = "Loading the model onto a GPU device failed due to an " \
        f"exception:\n{e}\nAttempting to load onto a CPU device instead."
      logging.warning(message)
      return _load_model(
          model_class,
          state_dict_path,
          torch.device('cpu'),
          model_params,
          torch_script_model_path,
          load_model_args)
    else:
      raise e

  model.to(device)
  model.eval()
  logging.info("Finished loading PyTorch model.")
  return model, device


def _convert_to_device(examples: torch.Tensor, device) -> torch.Tensor:
  """
  Converts examples to the given device.

  **NOTE:** A user may pass in device='GPU', but if a GPU is not detected in
  the environment, the examples must be converted back to CPU.
  """
  if examples.device != device:
    examples = examples.to(device)
  return examples


def default_tensor_inference_fn(
    batch: Sequence[torch.Tensor],
    model: torch.nn.Module,
    device: str,
    inference_args: Optional[Dict[str, Any]] = None,
    model_id: Optional[str] = None,
) -> Iterable[PredictionResult]:
  # torch.no_grad() mitigates GPU memory issues
  # https://github.com/apache/beam/issues/22811
  with torch.no_grad():
    batched_tensors = torch.stack(batch)
    batched_tensors = _convert_to_device(batched_tensors, device)
    predictions = model(batched_tensors, **inference_args)
    return utils._convert_to_result(batch, predictions, model_id)


def make_tensor_model_fn(model_fn: str) -> TensorInferenceFn:
  """
  Produces a TensorInferenceFn that uses a method of the model other than
  the forward() method.

  Args:
    model_fn: A string name of the method to be used. This is accessed through
      getattr(model, model_fn)
  """
  def attr_fn(
      batch: Sequence[torch.Tensor],
      model: torch.nn.Module,
      device: str,
      inference_args: Optional[Dict[str, Any]] = None,
      model_id: Optional[str] = None,
  ) -> Iterable[PredictionResult]:
    with torch.no_grad():
      batched_tensors = torch.stack(batch)
      batched_tensors = _convert_to_device(batched_tensors, device)
      pred_fn = getattr(model, model_fn)
      predictions = pred_fn(batched_tensors, **inference_args)
      return utils._convert_to_result(batch, predictions, model_id)

  return attr_fn
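

# A minimal sketch of a custom TensorInferenceFn, shown only to illustrate the
# signature expected by the `inference_fn` constructor argument; it is not
# part of the Beam API. It assumes a model whose forward() returns a tuple and
# keeps only the first element of that tuple.
def _example_first_output_inference_fn(
    batch: Sequence[torch.Tensor],
    model: torch.nn.Module,
    device: str,
    inference_args: Optional[Dict[str, Any]] = None,
    model_id: Optional[str] = None,
) -> Iterable[PredictionResult]:
  inference_args = inference_args or {}
  with torch.no_grad():
    # Batch and move the inputs the same way default_tensor_inference_fn does.
    batched_tensors = torch.stack(batch)
    batched_tensors = _convert_to_device(batched_tensors, device)
    # Keep only the first element of the model's output tuple (assumption).
    predictions = model(batched_tensors, **inference_args)[0]
    return utils._convert_to_result(batch, predictions, model_id)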


class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
                                             PredictionResult,
                                             torch.nn.Module]):
  def __init__(
      self,
      state_dict_path: Optional[str] = None,
      model_class: Optional[Callable[..., torch.nn.Module]] = None,
      model_params: Optional[Dict[str, Any]] = None,
      device: str = 'CPU',
      *,
      inference_fn: TensorInferenceFn = default_tensor_inference_fn,
      torch_script_model_path: Optional[str] = None,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      load_model_args: Optional[Dict[str, Any]] = None,
      **kwargs):
    """Implementation of the ModelHandler interface for PyTorch.

    Example Usage for torch model::

      pcoll | RunInference(PytorchModelHandlerTensor(state_dict_path="my_uri",
                                                     model_class="my_class"))

    Example Usage for torchscript model::

      pcoll | RunInference(PytorchModelHandlerTensor(
          torch_script_model_path="my_uri"))

    See https://pytorch.org/tutorials/beginner/saving_loading_models.html
    for details.

    Args:
      state_dict_path: path to the saved dictionary of the model state.
      model_class: class of the Pytorch model that defines the model
        structure.
      model_params: A dictionary of arguments required to instantiate the
        model class.
      device: the device on which you wish to run the model. If
        ``device = GPU`` then a GPU device will be used if it is available.
        Otherwise, it will be CPU.
      inference_fn: the inference function to use during RunInference.
        default=default_tensor_inference_fn
      torch_script_model_path: Path to the torch script model.
        The model will be loaded using `torch.jit.load()`.
        `state_dict_path`, `model_class` and `model_params`
        arguments will be disregarded.
      min_batch_size: the minimum batch size to use when batching inputs. This
        batch will be fed into the inference_fn as a Sequence of Tensors.
      max_batch_size: the maximum batch size to use when batching inputs. This
        batch will be fed into the inference_fn as a Sequence of Tensors.
      load_model_args: a dictionary of parameters passed to the torch.load
        function to specify custom config for loading models.
      kwargs: 'env_vars' can be used to set environment variables
        before loading the model.

    **Supported Versions:** RunInference APIs in Apache Beam have been tested
    with PyTorch 1.9 and 1.10.
    """
    self._state_dict_path = state_dict_path
    if device == 'GPU':
      logging.info("Device is set to CUDA")
      self._device = torch.device('cuda')
    else:
      logging.info("Device is set to CPU")
      self._device = torch.device('cpu')
    self._model_class = model_class
    self._model_params = model_params if model_params else {}
    self._inference_fn = inference_fn
    self._batching_kwargs = {}
    if min_batch_size is not None:
      self._batching_kwargs['min_batch_size'] = min_batch_size
    if max_batch_size is not None:
      self._batching_kwargs['max_batch_size'] = max_batch_size
    self._torch_script_model_path = torch_script_model_path
    self._load_model_args = load_model_args if load_model_args else {}
    self._env_vars = kwargs.get('env_vars', {})

    _validate_constructor_args(
        state_dict_path=self._state_dict_path,
        model_class=self._model_class,
        torch_script_model_path=self._torch_script_model_path)

  def load_model(self) -> torch.nn.Module:
    """Loads and initializes a Pytorch model for processing."""
    model, device = _load_model(
        model_class=self._model_class,
        state_dict_path=self._state_dict_path,
        device=self._device,
        model_params=self._model_params,
        torch_script_model_path=self._torch_script_model_path,
        load_model_args=self._load_model_args)
    self._device = device
    return model

  def update_model_path(self, model_path: Optional[str] = None):
    if self._torch_script_model_path:
      self._torch_script_model_path = (
          model_path if model_path else self._torch_script_model_path)
    else:
      self._state_dict_path = (
          model_path if model_path else self._state_dict_path)

  def run_inference(
      self,
      batch: Sequence[torch.Tensor],
      model: torch.nn.Module,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """
    Runs inferences on a batch of Tensors and returns an Iterable of
    Tensor Predictions.

    This method stacks the list of Tensors in a vectorized format to optimize
    the inference call.

    Args:
      batch: A sequence of Tensors. These Tensors should be batchable, as this
        method will call `torch.stack()` and pass in batched Tensors with
        dimensions (batch_size, n_features, etc.) into the model's forward()
        function.
      model: A PyTorch model.
      inference_args: Non-batchable arguments required as inputs to the
        model's forward() function. Unlike Tensors in `batch`, these
        parameters will not be dynamically batched.

    Returns:
      An Iterable of type PredictionResult.
    """
    inference_args = {} if not inference_args else inference_args
    model_id = (
        self._state_dict_path
        if not self._torch_script_model_path else self._torch_script_model_path)
    return self._inference_fn(
        batch, model, self._device, inference_args, model_id)

  def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
    """
    Returns:
      The number of bytes of data for a batch of Tensors.
    """
    return sum((el.element_size() for tensor in batch for el in tensor))

  def get_metrics_namespace(self) -> str:
    """
    Returns:
      A namespace for metrics collected by the RunInference transform.
    """
    return 'BeamML_PyTorch'

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    pass

  def batch_elements_kwargs(self):
    return self._batching_kwargs
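

# The keyed-tensor handler below expects each element of a batch to be a dict
# mapping string keys to Tensors, e.g. (hypothetical keys and shapes):
#
#   {'input_ids': torch.tensor([101, 2023, 102]),
#    'attention_mask': torch.tensor([1, 1, 1])}
#
# Tensors that share a key across the batch are stacked together before being
# passed to the model as keyword arguments.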


def default_keyed_tensor_inference_fn(
    batch: Sequence[Dict[str, torch.Tensor]],
    model: torch.nn.Module,
    device: str,
    inference_args: Optional[Dict[str, Any]] = None,
    model_id: Optional[str] = None,
) -> Iterable[PredictionResult]:
  # Elements in `batch` are provided as dictionaries from key to Tensor, so
  # iterate through the batch list and group the Tensors that share a key.
  key_to_tensor_list = defaultdict(list)

  # torch.no_grad() mitigates GPU memory issues
  # https://github.com/apache/beam/issues/22811
  with torch.no_grad():
    for example in batch:
      for key, tensor in example.items():
        key_to_tensor_list[key].append(tensor)
    key_to_batched_tensors = {}
    for key in key_to_tensor_list:
      batched_tensors = torch.stack(key_to_tensor_list[key])
      batched_tensors = _convert_to_device(batched_tensors, device)
      key_to_batched_tensors[key] = batched_tensors
    predictions = model(**key_to_batched_tensors, **inference_args)

  return utils._convert_to_result(batch, predictions, model_id)


def make_keyed_tensor_model_fn(model_fn: str) -> KeyedTensorInferenceFn:
  """
  Produces a KeyedTensorInferenceFn that uses a method of the model other than
  the forward() method.

  Args:
    model_fn: A string name of the method to be used. This is accessed through
      getattr(model, model_fn)
  """
  def attr_fn(
      batch: Sequence[Dict[str, torch.Tensor]],
      model: torch.nn.Module,
      device: str,
      inference_args: Optional[Dict[str, Any]] = None,
      model_id: Optional[str] = None,
  ) -> Iterable[PredictionResult]:
    # Elements in `batch` are provided as dictionaries from key to Tensor, so
    # iterate through the batch list and group the Tensors that share a key.
    key_to_tensor_list = defaultdict(list)

    # torch.no_grad() mitigates GPU memory issues
    # https://github.com/apache/beam/issues/22811
    with torch.no_grad():
      for example in batch:
        for key, tensor in example.items():
          key_to_tensor_list[key].append(tensor)
      key_to_batched_tensors = {}
      for key in key_to_tensor_list:
        batched_tensors = torch.stack(key_to_tensor_list[key])
        batched_tensors = _convert_to_device(batched_tensors, device)
        key_to_batched_tensors[key] = batched_tensors
      pred_fn = getattr(model, model_fn)
      predictions = pred_fn(**key_to_batched_tensors, **inference_args)
      return utils._convert_to_result(batch, predictions, model_id)

  return attr_fn
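

# Sketch of using make_keyed_tensor_model_fn to call a model method other than
# forward(); the 'generate' method name, model class, path and params below
# are hypothetical:
#
#   keyed_gen_fn = make_keyed_tensor_model_fn('generate')
#   handler = PytorchModelHandlerKeyedTensor(
#       state_dict_path='gs://my-bucket/model.pt',  # hypothetical path
#       model_class=MyKeyedModel,                   # hypothetical class
#       model_params={'config': my_config},         # hypothetical params
#       inference_fn=keyed_gen_fn)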


class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
                                                  PredictionResult,
                                                  torch.nn.Module]):
  def __init__(
      self,
      state_dict_path: Optional[str] = None,
      model_class: Optional[Callable[..., torch.nn.Module]] = None,
      model_params: Optional[Dict[str, Any]] = None,
      device: str = 'CPU',
      *,
      inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn,
      torch_script_model_path: Optional[str] = None,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      load_model_args: Optional[Dict[str, Any]] = None,
      **kwargs):
    """Implementation of the ModelHandler interface for PyTorch.

    Example Usage for torch model::

      pcoll | RunInference(PytorchModelHandlerKeyedTensor(
          state_dict_path="my_uri",
          model_class="my_class"))

    Example Usage for torchscript model::

      pcoll | RunInference(PytorchModelHandlerKeyedTensor(
          torch_script_model_path="my_uri"))

    **NOTE:** This API and its implementation are under development and
    do not provide backward compatibility guarantees.

    See https://pytorch.org/tutorials/beginner/saving_loading_models.html
    for details.

    Args:
      state_dict_path: path to the saved dictionary of the model state.
      model_class: class of the Pytorch model that defines the model
        structure.
      model_params: A dictionary of arguments required to instantiate the
        model class.
      device: the device on which you wish to run the model. If
        ``device = GPU`` then a GPU device will be used if it is available.
        Otherwise, it will be CPU.
      inference_fn: the function to invoke on run_inference.
        default = default_keyed_tensor_inference_fn
      torch_script_model_path: Path to the torch script model.
        The model will be loaded using `torch.jit.load()`.
        `state_dict_path`, `model_class` and `model_params`
        arguments will be disregarded.
      min_batch_size: the minimum batch size to use when batching inputs. This
        batch will be fed into the inference_fn as a Sequence of Keyed
        Tensors.
      max_batch_size: the maximum batch size to use when batching inputs. This
        batch will be fed into the inference_fn as a Sequence of Keyed
        Tensors.
      load_model_args: a dictionary of parameters passed to the torch.load
        function to specify custom config for loading models.
      kwargs: 'env_vars' can be used to set environment variables
        before loading the model.

    **Supported Versions:** RunInference APIs in Apache Beam have been tested
    on torch>=1.9.0,<1.14.0.
    """
    self._state_dict_path = state_dict_path
    if device == 'GPU':
      logging.info("Device is set to CUDA")
      self._device = torch.device('cuda')
    else:
      logging.info("Device is set to CPU")
      self._device = torch.device('cpu')
    self._model_class = model_class
    self._model_params = model_params if model_params else {}
    self._inference_fn = inference_fn
    self._batching_kwargs = {}
    if min_batch_size is not None:
      self._batching_kwargs['min_batch_size'] = min_batch_size
    if max_batch_size is not None:
      self._batching_kwargs['max_batch_size'] = max_batch_size
    self._torch_script_model_path = torch_script_model_path
    self._load_model_args = load_model_args if load_model_args else {}
    self._env_vars = kwargs.get('env_vars', {})
    _validate_constructor_args(
        state_dict_path=self._state_dict_path,
        model_class=self._model_class,
        torch_script_model_path=self._torch_script_model_path)

  def load_model(self) -> torch.nn.Module:
    """Loads and initializes a Pytorch model for processing."""
    model, device = _load_model(
        model_class=self._model_class,
        state_dict_path=self._state_dict_path,
        device=self._device,
        model_params=self._model_params,
        torch_script_model_path=self._torch_script_model_path,
        load_model_args=self._load_model_args)
    self._device = device
    return model

  def update_model_path(self, model_path: Optional[str] = None):
    if self._torch_script_model_path:
      self._torch_script_model_path = (
          model_path if model_path else self._torch_script_model_path)
    else:
      self._state_dict_path = (
          model_path if model_path else self._state_dict_path)

  def run_inference(
      self,
      batch: Sequence[Dict[str, torch.Tensor]],
      model: torch.nn.Module,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """
    Runs inferences on a batch of Keyed Tensors and returns an Iterable of
    Tensor Predictions.

    For the same key across all examples, this will stack all Tensor values
    in a vectorized format to optimize the inference call.

    Args:
      batch: A sequence of keyed Tensors. These Tensors should be batchable,
        as this method will call `torch.stack()` and pass in batched Tensors
        with dimensions (batch_size, n_features, etc.) into the model's
        forward() function.
      model: A PyTorch model.
      inference_args: Non-batchable arguments required as inputs to the
        model's forward() function. Unlike Tensors in `batch`, these
        parameters will not be dynamically batched.

    Returns:
      An Iterable of type PredictionResult.
    """
    inference_args = {} if not inference_args else inference_args
    model_id = (
        self._state_dict_path
        if not self._torch_script_model_path else self._torch_script_model_path)
    return self._inference_fn(
        batch, model, self._device, inference_args, model_id)

  def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
    """
    Returns:
      The number of bytes of data for a batch of Dicts of Tensors.
    """
    # Elements in `batch` are provided as dictionaries from key to Tensor.
    return sum(
        (el.element_size() for tensor in batch for el in tensor.values()))

  def get_metrics_namespace(self) -> str:
    """
    Returns:
      A namespace for metrics collected by the RunInference transform.
    """
    return 'BeamML_PyTorch'

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    pass

  def batch_elements_kwargs(self):
    return self._batching_kwargs
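

# End-to-end usage sketch (illustrative only; the bucket path, MyLinearModel
# class and its parameters are hypothetical):
#
#   import apache_beam as beam
#   from apache_beam.ml.inference.base import RunInference
#
#   with beam.Pipeline() as p:
#     _ = (
#         p
#         | beam.Create([torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])])
#         | RunInference(
#             PytorchModelHandlerTensor(
#                 state_dict_path='gs://my-bucket/model.pt',
#                 model_class=MyLinearModel,
#                 model_params={'input_dim': 2})))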