github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/pytorch_inference.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  import logging
    21  from collections import defaultdict
    22  from typing import Any
    23  from typing import Callable
    24  from typing import Dict
    25  from typing import Iterable
    26  from typing import Optional
    27  from typing import Sequence
    28  
    29  import torch
    30  from apache_beam.io.filesystems import FileSystems
    31  from apache_beam.ml.inference import utils
    32  from apache_beam.ml.inference.base import ModelHandler
    33  from apache_beam.ml.inference.base import PredictionResult
    34  
    35  __all__ = [
    36      'PytorchModelHandlerTensor',
    37      'PytorchModelHandlerKeyedTensor',
    38  ]
    39  
    40  TensorInferenceFn = Callable[[
    41      Sequence[torch.Tensor],
    42      torch.nn.Module,
    43      torch.device,
    44      Optional[Dict[str, Any]],
    45      Optional[str]
    46  ],
    47                               Iterable[PredictionResult]]
    48  
    49  KeyedTensorInferenceFn = Callable[[
    50      Sequence[Dict[str, torch.Tensor]],
    51      torch.nn.Module,
    52      torch.device,
    53      Optional[Dict[str, Any]],
    54      Optional[str]
    55  ],
    56                                    Iterable[PredictionResult]]
    57  
    58  
    59  def _validate_constructor_args(
    60      state_dict_path, model_class, torch_script_model_path):
    61    message = (
    62        "A {param1} has been supplied to the model "
    63        "handler, but the required {param2} is missing. "
    64        "Please provide the {param2} in order to "
    65        "successfully load the {param1}.")
    66    # state_dict_path and model_class are coupled with each other
    67    # raise RuntimeError if user forgets to pass any one of them.
    68    if state_dict_path and not model_class:
    69      raise RuntimeError(
    70          message.format(param1="state_dict_path", param2="model_class"))
    71  
    72    if not state_dict_path and model_class:
    73      raise RuntimeError(
    74          message.format(param1="model_class", param2="state_dict_path"))
    75  
    76    if torch_script_model_path and state_dict_path:
    77      raise RuntimeError(
    78          "Please specify either torch_script_model_path or "
    79          "(state_dict_path, model_class) to successfully load the model.")
    80  
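         # Illustrative sketch of the two argument combinations that pass
         # _validate_constructor_args; the paths, `MyModel`, and its parameters are
         # hypothetical.
         #
         #   PytorchModelHandlerTensor(
         #       state_dict_path='gs://bucket/model.pth',
         #       model_class=MyModel,
         #       model_params={'input_dim': 16})
         #
         #   PytorchModelHandlerTensor(
         #       torch_script_model_path='gs://bucket/scripted_model.pt')
         #
         # Passing state_dict_path without model_class (or vice versa), or combining
         # torch_script_model_path with state_dict_path, raises RuntimeError above.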
    81  
    82  def _load_model(
    83      model_class: Optional[Callable[..., torch.nn.Module]],
    84      state_dict_path: Optional[str],
    85      device: torch.device,
    86      model_params: Optional[Dict[str, Any]],
    87      torch_script_model_path: Optional[str],
    88      load_model_args: Optional[Dict[str, Any]]):
    89    if device == torch.device('cuda') and not torch.cuda.is_available():
    90      logging.warning(
    91          "Model handler specified a 'GPU' device, but GPUs are not available. "
    92          "Switching to CPU.")
    93      device = torch.device('cpu')
    94  
    95    try:
    96      logging.info(
    97          "Loading state_dict_path %s onto a %s device", state_dict_path, device)
    98      if not torch_script_model_path:
    99        file = FileSystems.open(state_dict_path, 'rb')
   100        model = model_class(**model_params)  # type: ignore[arg-type,misc]
   101        state_dict = torch.load(file, map_location=device, **load_model_args)
   102        model.load_state_dict(state_dict)
   103      else:
   104        file = FileSystems.open(torch_script_model_path, 'rb')
   105        model = torch.jit.load(file, map_location=device, **load_model_args)
   106    except RuntimeError as e:
   107      if device == torch.device('cuda'):
   108        message = "Loading the model onto a GPU device failed due to an " \
   109          f"exception:\n{e}\nAttempting to load onto a CPU device instead."
   110        logging.warning(message)
   111        return _load_model(
   112            model_class,
   113            state_dict_path,
   114            torch.device('cpu'),
   115            model_params,
   116            torch_script_model_path,
   117            load_model_args)
   118      else:
   119        raise e
   120  
   121    model.to(device)
   122    model.eval()
   123    logging.info("Finished loading PyTorch model.")
   124    return model, device
   125  
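         # Illustrative sketch: `load_model_args` is forwarded as keyword arguments to
         # torch.load() / torch.jit.load(), e.g. (assuming a torch version that
         # supports the flag; path and class are hypothetical)
         #
         #   PytorchModelHandlerTensor(
         #       state_dict_path='gs://bucket/model.pth',
         #       model_class=MyModel,
         #       load_model_args={'weights_only': True})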
   126  
   127  def _convert_to_device(examples: torch.Tensor, device) -> torch.Tensor:
   128    """
    129    Moves the examples to the given device if they are not already there.
    130
    131    **NOTE:** A user may pass in device='GPU', but if a GPU is not detected in
    132    the environment the examples must be moved back to CPU.
   133    """
   134    if examples.device != device:
   135      examples = examples.to(device)
   136    return examples
   137  
   138  
   139  def default_tensor_inference_fn(
   140      batch: Sequence[torch.Tensor],
   141      model: torch.nn.Module,
   142      device: str,
   143      inference_args: Optional[Dict[str, Any]] = None,
   144      model_id: Optional[str] = None,
   145  ) -> Iterable[PredictionResult]:
   146    # torch.no_grad() mitigates GPU memory issues
   147    # https://github.com/apache/beam/issues/22811
   148    with torch.no_grad():
   149      batched_tensors = torch.stack(batch)
   150      batched_tensors = _convert_to_device(batched_tensors, device)
   151      predictions = model(batched_tensors, **inference_args)
   152      return utils._convert_to_result(batch, predictions, model_id)
   153  
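         # Illustrative sketch: `inference_args` are passed through as keyword
         # arguments to the model call above. For example, a hypothetical keyword
         # accepted by the model's forward() could be supplied as
         #
         #   RunInference(handler, inference_args={'return_dict': False})
         #
         # which would result in model(batched_tensors, return_dict=False).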
   154  
   155  def make_tensor_model_fn(model_fn: str) -> TensorInferenceFn:
   156    """
    157    Produces a TensorInferenceFn that uses a method of the model other than
    158    the forward() method.
   159  
   160    Args:
   161      model_fn: A string name of the method to be used. This is accessed through
    162        getattr(model, model_fn).
   163    """
   164    def attr_fn(
   165        batch: Sequence[torch.Tensor],
   166        model: torch.nn.Module,
   167        device: str,
   168        inference_args: Optional[Dict[str, Any]] = None,
   169        model_id: Optional[str] = None,
   170    ) -> Iterable[PredictionResult]:
   171      with torch.no_grad():
   172        batched_tensors = torch.stack(batch)
   173        batched_tensors = _convert_to_device(batched_tensors, device)
   174        pred_fn = getattr(model, model_fn)
   175        predictions = pred_fn(batched_tensors, **inference_args)
   176        return utils._convert_to_result(batch, predictions, model_id)
   177  
   178    return attr_fn
   179  
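         # Illustrative sketch: using make_tensor_model_fn to run a method other than
         # forward(). The `generate` method name, the path, and `MyModel` are
         # hypothetical.
         #
         #   gen_fn = make_tensor_model_fn('generate')
         #   handler = PytorchModelHandlerTensor(
         #       state_dict_path='gs://bucket/model.pth',
         #       model_class=MyModel,
         #       inference_fn=gen_fn)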
   180  
   181  class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
   182                                               PredictionResult,
   183                                               torch.nn.Module]):
   184    def __init__(
   185        self,
   186        state_dict_path: Optional[str] = None,
   187        model_class: Optional[Callable[..., torch.nn.Module]] = None,
   188        model_params: Optional[Dict[str, Any]] = None,
   189        device: str = 'CPU',
   190        *,
   191        inference_fn: TensorInferenceFn = default_tensor_inference_fn,
   192        torch_script_model_path: Optional[str] = None,
   193        min_batch_size: Optional[int] = None,
   194        max_batch_size: Optional[int] = None,
   195        load_model_args: Optional[Dict[str, Any]] = None,
   196        **kwargs):
   197      """Implementation of the ModelHandler interface for PyTorch.
   198  
   199      Example Usage for torch model::
    200        pcoll | RunInference(PytorchModelHandlerTensor(state_dict_path="my_uri",
    201                                                       model_class=MyModelClass))
   202      Example Usage for torchscript model::
   203        pcoll | RunInference(PytorchModelHandlerTensor(
   204          torch_script_model_path="my_uri"))
   205  
   206      See https://pytorch.org/tutorials/beginner/saving_loading_models.html
    207      for details.
   208  
   209      Args:
   210        state_dict_path: path to the saved dictionary of the model state.
   211        model_class: class of the Pytorch model that defines the model
   212          structure.
   213        model_params: A dictionary of arguments required to instantiate the model
   214          class.
   215        device: the device on which you wish to run the model. If
   216          ``device = GPU`` then a GPU device will be used if it is available.
   217          Otherwise, it will be CPU.
   218        inference_fn: the inference function to use during RunInference.
    219        default=default_tensor_inference_fn
    220        torch_script_model_path: Path to the TorchScript model. The model
    221          will be loaded using `torch.jit.load()`, and the `state_dict_path`,
    222          `model_class` and `model_params` arguments will be disregarded.
   224        min_batch_size: the minimum batch size to use when batching inputs. This
   225          batch will be fed into the inference_fn as a Sequence of Tensors.
   226        max_batch_size: the maximum batch size to use when batching inputs. This
   227          batch will be fed into the inference_fn as a Sequence of Tensors.
   228        load_model_args: a dictionary of parameters passed to the torch.load
   229          function to specify custom config for loading models.
   230        kwargs: 'env_vars' can be used to set environment variables
   231          before loading the model.
   232  
   233      **Supported Versions:** RunInference APIs in Apache Beam have been tested
   234      with PyTorch 1.9 and 1.10.
   235      """
   236      self._state_dict_path = state_dict_path
   237      if device == 'GPU':
   238        logging.info("Device is set to CUDA")
   239        self._device = torch.device('cuda')
   240      else:
   241        logging.info("Device is set to CPU")
   242        self._device = torch.device('cpu')
   243      self._model_class = model_class
   244      self._model_params = model_params if model_params else {}
   245      self._inference_fn = inference_fn
   246      self._batching_kwargs = {}
   247      if min_batch_size is not None:
   248        self._batching_kwargs['min_batch_size'] = min_batch_size
   249      if max_batch_size is not None:
   250        self._batching_kwargs['max_batch_size'] = max_batch_size
   251      self._torch_script_model_path = torch_script_model_path
   252      self._load_model_args = load_model_args if load_model_args else {}
   253      self._env_vars = kwargs.get('env_vars', {})
   254  
   255      _validate_constructor_args(
   256          state_dict_path=self._state_dict_path,
   257          model_class=self._model_class,
   258          torch_script_model_path=self._torch_script_model_path)
   259  
   260    def load_model(self) -> torch.nn.Module:
   261      """Loads and initializes a Pytorch model for processing."""
   262      model, device = _load_model(
   263          model_class=self._model_class,
   264          state_dict_path=self._state_dict_path,
   265          device=self._device,
   266          model_params=self._model_params,
   267          torch_script_model_path=self._torch_script_model_path,
   268          load_model_args=self._load_model_args
   269      )
   270      self._device = device
   271      return model
   272  
   273    def update_model_path(self, model_path: Optional[str] = None):
   274      if self._torch_script_model_path:
   275        self._torch_script_model_path = (
   276            model_path if model_path else self._torch_script_model_path)
   277      else:
   278        self._state_dict_path = (
   279            model_path if model_path else self._state_dict_path)
   280  
   281    def run_inference(
   282        self,
   283        batch: Sequence[torch.Tensor],
   284        model: torch.nn.Module,
   285        inference_args: Optional[Dict[str, Any]] = None
   286    ) -> Iterable[PredictionResult]:
   287      """
   288      Runs inferences on a batch of Tensors and returns an Iterable of
   289      Tensor Predictions.
   290  
   291      This method stacks the list of Tensors in a vectorized format to optimize
   292      the inference call.
   293  
   294      Args:
   295        batch: A sequence of Tensors. These Tensors should be batchable, as this
   296          method will call `torch.stack()` and pass in batched Tensors with
   297          dimensions (batch_size, n_features, etc.) into the model's forward()
   298          function.
   299        model: A PyTorch model.
   300        inference_args: Non-batchable arguments required as inputs to the model's
   301          forward() function. Unlike Tensors in `batch`, these parameters will
    302          not be dynamically batched.
   303  
   304      Returns:
   305        An Iterable of type PredictionResult.
   306      """
   307      inference_args = {} if not inference_args else inference_args
   308      model_id = (
   309          self._state_dict_path
   310          if not self._torch_script_model_path else self._torch_script_model_path)
   311      return self._inference_fn(
   312          batch, model, self._device, inference_args, model_id)
   313  
   314    def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
   315      """
   316      Returns:
   317        The number of bytes of data for a batch of Tensors.
   318      """
   319      return sum((el.element_size() for tensor in batch for el in tensor))
   320  
   321    def get_metrics_namespace(self) -> str:
   322      """
   323      Returns:
   324         A namespace for metrics collected by the RunInference transform.
   325      """
   326      return 'BeamML_PyTorch'
   327  
   328    def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
   329      pass
   330  
   331    def batch_elements_kwargs(self):
   332      return self._batching_kwargs
   333  
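         # Illustrative pipeline sketch for PytorchModelHandlerTensor; the path,
         # `MyModel`, and the input values are hypothetical.
         #
         #   import apache_beam as beam
         #   from apache_beam.ml.inference.base import RunInference
         #
         #   with beam.Pipeline() as p:
         #     _ = (
         #         p
         #         | beam.Create([torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])])
         #         | RunInference(
         #             PytorchModelHandlerTensor(
         #                 state_dict_path='gs://bucket/model.pth',
         #                 model_class=MyModel,
         #                 model_params={'input_dim': 2}))
         #         | beam.Map(print))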
   334  
   335  def default_keyed_tensor_inference_fn(
   336      batch: Sequence[Dict[str, torch.Tensor]],
   337      model: torch.nn.Module,
   338      device: str,
   339      inference_args: Optional[Dict[str, Any]] = None,
   340      model_id: Optional[str] = None,
   341  ) -> Iterable[PredictionResult]:
    342    # Elements in `batch` are dictionaries from key to Tensor; iterate through
    343    # the batch list and group the Tensors that share the same key.
   344    key_to_tensor_list = defaultdict(list)
   345  
   346    # torch.no_grad() mitigates GPU memory issues
   347    # https://github.com/apache/beam/issues/22811
   348    with torch.no_grad():
   349      for example in batch:
   350        for key, tensor in example.items():
   351          key_to_tensor_list[key].append(tensor)
   352      key_to_batched_tensors = {}
   353      for key in key_to_tensor_list:
   354        batched_tensors = torch.stack(key_to_tensor_list[key])
   355        batched_tensors = _convert_to_device(batched_tensors, device)
   356        key_to_batched_tensors[key] = batched_tensors
   357      predictions = model(**key_to_batched_tensors, **inference_args)
   358  
   359      return utils._convert_to_result(batch, predictions, model_id)
   360  
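         # Illustrative sketch of what the default fn above receives: each element of
         # `batch` is a dict from key to Tensor, for example (key names hypothetical)
         #
         #   {'input_ids': torch.tensor([101, 2023, 102]),
         #    'attention_mask': torch.tensor([1, 1, 1])}
         #
         # Tensors sharing a key are stacked across the batch and passed to the model
         # as keyword arguments, so the keys must match the model's forward() signature.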
   361  
   362  def make_keyed_tensor_model_fn(model_fn: str) -> KeyedTensorInferenceFn:
   363    """
    364    Produces a KeyedTensorInferenceFn that uses a method of the model other than
    365    the forward() method.
   366  
   367    Args:
   368      model_fn: A string name of the method to be used. This is accessed through
    369        getattr(model, model_fn).
   370    """
   371    def attr_fn(
   372        batch: Sequence[Dict[str, torch.Tensor]],
   373        model: torch.nn.Module,
   374        device: str,
   375        inference_args: Optional[Dict[str, Any]] = None,
   376        model_id: Optional[str] = None,
   377    ) -> Iterable[PredictionResult]:
    378      # Elements in `batch` are dictionaries from key to Tensor; iterate through
    379      # the batch list and group the Tensors that share the same key.
   380      key_to_tensor_list = defaultdict(list)
   381  
   382      # torch.no_grad() mitigates GPU memory issues
   383      # https://github.com/apache/beam/issues/22811
   384      with torch.no_grad():
   385        for example in batch:
   386          for key, tensor in example.items():
   387            key_to_tensor_list[key].append(tensor)
   388        key_to_batched_tensors = {}
   389        for key in key_to_tensor_list:
   390          batched_tensors = torch.stack(key_to_tensor_list[key])
   391          batched_tensors = _convert_to_device(batched_tensors, device)
   392          key_to_batched_tensors[key] = batched_tensors
   393        pred_fn = getattr(model, model_fn)
   394        predictions = pred_fn(**key_to_batched_tensors, **inference_args)
   395      return utils._convert_to_result(batch, predictions, model_id)
   396  
   397    return attr_fn
   398  
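         # Illustrative sketch: using make_keyed_tensor_model_fn to run a method other
         # than forward(). The `generate` method name and the path are hypothetical.
         #
         #   gen_fn = make_keyed_tensor_model_fn('generate')
         #   handler = PytorchModelHandlerKeyedTensor(
         #       torch_script_model_path='gs://bucket/scripted_model.pt',
         #       inference_fn=gen_fn)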
   399  
   400  class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
   401                                                    PredictionResult,
   402                                                    torch.nn.Module]):
   403    def __init__(
   404        self,
   405        state_dict_path: Optional[str] = None,
   406        model_class: Optional[Callable[..., torch.nn.Module]] = None,
   407        model_params: Optional[Dict[str, Any]] = None,
   408        device: str = 'CPU',
   409        *,
   410        inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn,
   411        torch_script_model_path: Optional[str] = None,
   412        min_batch_size: Optional[int] = None,
   413        max_batch_size: Optional[int] = None,
   414        load_model_args: Optional[Dict[str, Any]] = None,
   415        **kwargs):
   416      """Implementation of the ModelHandler interface for PyTorch.
   417  
    418      Example Usage for torch model::
   419        pcoll | RunInference(PytorchModelHandlerKeyedTensor(
   420          state_dict_path="my_uri",
    421          model_class=MyModelClass))
   422  
   423      Example Usage for torchscript model::
   424        pcoll | RunInference(PytorchModelHandlerKeyedTensor(
   425          torch_script_model_path="my_uri"))
   426  
   427      **NOTE:** This API and its implementation are under development and
   428      do not provide backward compatibility guarantees.
   429  
   430      See https://pytorch.org/tutorials/beginner/saving_loading_models.html
    431      for details.
   432  
   433      Args:
   434        state_dict_path: path to the saved dictionary of the model state.
   435        model_class: class of the Pytorch model that defines the model
   436          structure.
   437        model_params: A dictionary of arguments required to instantiate the model
   438          class.
   439        device: the device on which you wish to run the model. If
   440          ``device = GPU`` then a GPU device will be used if it is available.
   441          Otherwise, it will be CPU.
    442        inference_fn: the inference function to use during RunInference.
    443          default=default_keyed_tensor_inference_fn
    444        torch_script_model_path: Path to the TorchScript model. The model
    445          will be loaded using `torch.jit.load()`, and the `state_dict_path`,
    446          `model_class` and `model_params` arguments will be disregarded.
   448        min_batch_size: the minimum batch size to use when batching inputs. This
   449          batch will be fed into the inference_fn as a Sequence of Keyed Tensors.
   450        max_batch_size: the maximum batch size to use when batching inputs. This
   451          batch will be fed into the inference_fn as a Sequence of Keyed Tensors.
   452        load_model_args: a dictionary of parameters passed to the torch.load
   453          function to specify custom config for loading models.
   454        kwargs: 'env_vars' can be used to set environment variables
   455          before loading the model.
   456  
   457      **Supported Versions:** RunInference APIs in Apache Beam have been tested
   458      on torch>=1.9.0,<1.14.0.
   459      """
   460      self._state_dict_path = state_dict_path
   461      if device == 'GPU':
   462        logging.info("Device is set to CUDA")
   463        self._device = torch.device('cuda')
   464      else:
   465        logging.info("Device is set to CPU")
   466        self._device = torch.device('cpu')
   467      self._model_class = model_class
   468      self._model_params = model_params if model_params else {}
   469      self._inference_fn = inference_fn
   470      self._batching_kwargs = {}
   471      if min_batch_size is not None:
   472        self._batching_kwargs['min_batch_size'] = min_batch_size
   473      if max_batch_size is not None:
   474        self._batching_kwargs['max_batch_size'] = max_batch_size
   475      self._torch_script_model_path = torch_script_model_path
   476      self._load_model_args = load_model_args if load_model_args else {}
   477      self._env_vars = kwargs.get('env_vars', {})
   478      _validate_constructor_args(
   479          state_dict_path=self._state_dict_path,
   480          model_class=self._model_class,
   481          torch_script_model_path=self._torch_script_model_path)
   482  
   483    def load_model(self) -> torch.nn.Module:
   484      """Loads and initializes a Pytorch model for processing."""
   485      model, device = _load_model(
   486          model_class=self._model_class,
   487          state_dict_path=self._state_dict_path,
   488          device=self._device,
   489          model_params=self._model_params,
   490          torch_script_model_path=self._torch_script_model_path,
   491          load_model_args=self._load_model_args
   492      )
   493      self._device = device
   494      return model
   495  
   496    def update_model_path(self, model_path: Optional[str] = None):
   497      if self._torch_script_model_path:
   498        self._torch_script_model_path = (
   499            model_path if model_path else self._torch_script_model_path)
   500      else:
   501        self._state_dict_path = (
   502            model_path if model_path else self._state_dict_path)
   503  
   504    def run_inference(
   505        self,
   506        batch: Sequence[Dict[str, torch.Tensor]],
   507        model: torch.nn.Module,
   508        inference_args: Optional[Dict[str, Any]] = None
   509    ) -> Iterable[PredictionResult]:
   510      """
   511      Runs inferences on a batch of Keyed Tensors and returns an Iterable of
   512      Tensor Predictions.
   513  
    514      For the same key across all examples, this will stack all Tensor values
   515      in a vectorized format to optimize the inference call.
   516  
   517      Args:
   518        batch: A sequence of keyed Tensors. These Tensors should be batchable,
   519          as this method will call `torch.stack()` and pass in batched Tensors
   520          with dimensions (batch_size, n_features, etc.) into the model's
   521          forward() function.
   522        model: A PyTorch model.
   523        inference_args: Non-batchable arguments required as inputs to the model's
   524          forward() function. Unlike Tensors in `batch`, these parameters will
    525          not be dynamically batched.
   526  
   527      Returns:
   528        An Iterable of type PredictionResult.
   529      """
   530      inference_args = {} if not inference_args else inference_args
   531      model_id = (
   532          self._state_dict_path
   533          if not self._torch_script_model_path else self._torch_script_model_path)
   534      return self._inference_fn(
   535          batch, model, self._device, inference_args, model_id)
   536  
    537    def get_num_bytes(self, batch: Sequence[Dict[str, torch.Tensor]]) -> int:
   538      """
   539      Returns:
   540         The number of bytes of data for a batch of Dict of Tensors.
   541      """
    542      # Elements in `batch` are dictionaries from key to Tensor.
   543      return sum(
   544          (el.element_size() for tensor in batch for el in tensor.values()))
   545  
   546    def get_metrics_namespace(self) -> str:
   547      """
   548      Returns:
   549         A namespace for metrics collected by the RunInference transform.
   550      """
   551      return 'BeamML_PyTorch'
   552  
   553    def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
   554      pass
   555  
   556    def batch_elements_kwargs(self):
   557      return self._batching_kwargs
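
         # Illustrative pipeline sketch for PytorchModelHandlerKeyedTensor; the path,
         # `MyKeyedModel`, and the key names are hypothetical.
         #
         #   import apache_beam as beam
         #   from apache_beam.ml.inference.base import RunInference
         #
         #   with beam.Pipeline() as p:
         #     _ = (
         #         p
         #         | beam.Create([{
         #             'input_ids': torch.tensor([101, 2023, 102]),
         #             'attention_mask': torch.tensor([1, 1, 1]),
         #         }])
         #         | RunInference(
         #             PytorchModelHandlerKeyedTensor(
         #                 state_dict_path='gs://bucket/model.pth',
         #                 model_class=MyKeyedModel,
         #                 model_params={}))
         #         | beam.Map(print))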