github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/tensorflow_inference.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import enum
import sys
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence
from typing import Union

import numpy

import tensorflow as tf
import tensorflow_hub as hub
from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult

__all__ = [
    'TFModelHandlerNumpy',
    'TFModelHandlerTensor',
]

TensorInferenceFn = Callable[[
    tf.Module,
    Sequence[Union[numpy.ndarray, tf.Tensor]],
    Dict[str, Any],
    Optional[str]
],
                             Iterable[PredictionResult]]


class ModelType(enum.Enum):
  """Defines how a model file should be loaded."""
  SAVED_MODEL = 1
  SAVED_WEIGHTS = 2


def _load_model(model_uri, custom_weights, load_model_args):
  # Resolve TF Hub handles to a local path, load the full saved model, and
  # optionally apply separately stored custom weights on top of it.
  model = tf.keras.models.load_model(hub.resolve(model_uri), **load_model_args)
  if custom_weights:
    model.load_weights(custom_weights)
  return model


def _load_model_from_weights(create_model_fn, weights_path):
  # Build a fresh model with the user-provided factory and restore its weights.
  model = create_model_fn()
  model.load_weights(weights_path)
  return model
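
# Illustrative sketch (not part of this module): with ModelType.SAVED_WEIGHTS,
# `create_model_fn` must rebuild the architecture that the weights were
# trained with; the handler then restores the weights onto it via
# `load_weights()`. The layer sizes and the weights path below are
# hypothetical, and the handler class is defined later in this module.
#
#   def create_model():
#     return tf.keras.Sequential([
#         tf.keras.layers.Dense(128, activation='relu', input_shape=(784, )),
#         tf.keras.layers.Dense(10),
#     ])
#
#   handler = TFModelHandlerNumpy(
#       model_uri='gs://my-bucket/checkpoints/my_weights',
#       model_type=ModelType.SAVED_WEIGHTS,
#       create_model_fn=create_model)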


def default_numpy_inference_fn(
    model: tf.Module,
    batch: Sequence[numpy.ndarray],
    inference_args: Dict[str, Any],
    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
  vectorized_batch = numpy.stack(batch, axis=0)
  predictions = model(vectorized_batch, **inference_args)
  return utils._convert_to_result(batch, predictions, model_id)


def default_tensor_inference_fn(
    model: tf.Module,
    batch: Sequence[tf.Tensor],
    inference_args: Dict[str, Any],
    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
  vectorized_batch = tf.stack(batch, axis=0)
  predictions = model(vectorized_batch, **inference_args)
  return utils._convert_to_result(batch, predictions, model_id)
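
# Illustrative sketch (not part of this module): a custom inference function
# must match the TensorInferenceFn signature defined above. This variant, for
# example, routes the batch through `model.predict()` instead of calling the
# model directly; the handler wiring via `inference_fn=` is shown below, the
# model URI is hypothetical, and the handler class is defined next.
#
#   def predict_numpy_inference_fn(
#       model: tf.Module,
#       batch: Sequence[numpy.ndarray],
#       inference_args: Dict[str, Any],
#       model_id: Optional[str] = None) -> Iterable[PredictionResult]:
#     vectorized_batch = numpy.stack(batch, axis=0)
#     predictions = model.predict(vectorized_batch, **inference_args)
#     return utils._convert_to_result(batch, predictions, model_id)
#
#   handler = TFModelHandlerNumpy(
#       model_uri='gs://my-bucket/my_saved_model',
#       inference_fn=predict_numpy_inference_fn)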


class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
                                       PredictionResult,
                                       tf.Module]):
  def __init__(
      self,
      model_uri: str,
      model_type: ModelType = ModelType.SAVED_MODEL,
      create_model_fn: Optional[Callable] = None,
      *,
      load_model_args: Optional[Dict[str, Any]] = None,
      custom_weights: str = "",
      inference_fn: TensorInferenceFn = default_numpy_inference_fn,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      **kwargs):
    """Implementation of the ModelHandler interface for TensorFlow.

    Example Usage::

      pcoll | RunInference(TFModelHandlerNumpy(model_uri="my_uri"))

    See https://www.tensorflow.org/tutorials/keras/save_and_load for details.

    Args:
        model_uri (str): path to the trained model.
        model_type: type of model to be loaded. Defaults to SAVED_MODEL.
        create_model_fn: a function that creates and returns a new
          TensorFlow model into which the saved weights are loaded.
          It should be used with ModelType.SAVED_WEIGHTS.
        load_model_args: a dictionary of parameters passed to TensorFlow's
          load_model function to specify custom configuration.
        custom_weights (str): path to the custom weights to be applied
          once the model is loaded.
        inference_fn: inference function to use during RunInference.
          Defaults to default_numpy_inference_fn.
        min_batch_size: the minimum batch size to use when batching inputs.
        max_batch_size: the maximum batch size to use when batching inputs.
        kwargs: 'env_vars' can be used to set environment variables
          before loading the model.

    **Supported Versions:** RunInference APIs in Apache Beam have been tested
    with TensorFlow 2.9, 2.10, and 2.11.
    """
    self._model_uri = model_uri
    self._model_type = model_type
    self._inference_fn = inference_fn
    self._create_model_fn = create_model_fn
    self._env_vars = kwargs.get('env_vars', {})
    self._load_model_args = {} if not load_model_args else load_model_args
    self._custom_weights = custom_weights
    self._batching_kwargs = {}
    if min_batch_size is not None:
      self._batching_kwargs['min_batch_size'] = min_batch_size
    if max_batch_size is not None:
      self._batching_kwargs['max_batch_size'] = max_batch_size

  def load_model(self) -> tf.Module:
    """Loads and initializes a TensorFlow model for processing."""
    if self._model_type == ModelType.SAVED_WEIGHTS:
      if not self._create_model_fn:
        raise ValueError(
            "Callable create_model_fn must be passed "
            "with ModelType.SAVED_WEIGHTS")
      return _load_model_from_weights(self._create_model_fn, self._model_uri)
    return _load_model(
        self._model_uri, self._custom_weights, self._load_model_args)

  def update_model_path(self, model_path: Optional[str] = None):
    self._model_uri = model_path if model_path else self._model_uri

  def run_inference(
      self,
      batch: Sequence[numpy.ndarray],
      model: tf.Module,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """
    Runs inferences on a batch of numpy arrays and returns an Iterable of
    numpy array predictions.

    This method stacks the n-dimensional numpy arrays in a vectorized format to
    optimize the inference call.

    Args:
      batch: A sequence of numpy nd-arrays. These should be batchable, as this
        method will call `numpy.stack()` and pass a batched numpy nd-array
        with dimensions (batch_size, n_features, etc.) into the model's
        predict() function.
      model: A TensorFlow model.
      inference_args: any additional keyword arguments passed to the model's
        inference call.

    Returns:
      An Iterable of type PredictionResult.
    """
    inference_args = {} if not inference_args else inference_args
    return self._inference_fn(model, batch, inference_args, self._model_uri)

  def get_num_bytes(self, batch: Sequence[numpy.ndarray]) -> int:
    """
    Returns:
      The number of bytes of data for a batch of numpy arrays.
    """
    return sum(sys.getsizeof(element) for element in batch)

  def get_metrics_namespace(self) -> str:
    """
    Returns:
       A namespace for metrics collected by the RunInference transform.
    """
    return 'BeamML_TF_Numpy'

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    pass

  def batch_elements_kwargs(self):
    return self._batching_kwargs
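
# A minimal usage sketch (illustrative; the model path, element values, and
# batch sizes are hypothetical): the handler plugs into RunInference, and the
# optional min/max batch sizes are forwarded to BatchElements through
# batch_elements_kwargs().
#
#   import apache_beam as beam
#   from apache_beam.ml.inference.base import RunInference
#
#   handler = TFModelHandlerNumpy(
#       model_uri='gs://my-bucket/my_saved_model',
#       min_batch_size=8,
#       max_batch_size=64)
#
#   with beam.Pipeline() as p:
#     _ = (
#         p
#         | beam.Create([numpy.array([1.0, 2.0, 3.0])])
#         | RunInference(handler)
#         | beam.Map(lambda result: result.inference))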


class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
                                        tf.Module]):
  def __init__(
      self,
      model_uri: str,
      model_type: ModelType = ModelType.SAVED_MODEL,
      create_model_fn: Optional[Callable] = None,
      *,
      load_model_args: Optional[Dict[str, Any]] = None,
      custom_weights: str = "",
      inference_fn: TensorInferenceFn = default_tensor_inference_fn,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      **kwargs):
    """Implementation of the ModelHandler interface for TensorFlow.

    Example Usage::

      pcoll | RunInference(TFModelHandlerTensor(model_uri="my_uri"))

    See https://www.tensorflow.org/tutorials/keras/save_and_load for details.

    Args:
        model_uri (str): path to the trained model.
        model_type: type of model to be loaded.
          Defaults to SAVED_MODEL.
        create_model_fn: a function that creates and returns a new
          TensorFlow model into which the saved weights are loaded.
          It should be used with ModelType.SAVED_WEIGHTS.
        load_model_args: a dictionary of parameters passed to TensorFlow's
          load_model function to specify custom configuration.
        custom_weights (str): path to the custom weights to be applied
          once the model is loaded.
        inference_fn: inference function to use during RunInference.
          Defaults to default_tensor_inference_fn.
        min_batch_size: the minimum batch size to use when batching inputs.
        max_batch_size: the maximum batch size to use when batching inputs.
        kwargs: 'env_vars' can be used to set environment variables
          before loading the model.

    **Supported Versions:** RunInference APIs in Apache Beam have been tested
    with TensorFlow 2.11.
    """
    self._model_uri = model_uri
    self._model_type = model_type
    self._inference_fn = inference_fn
    self._create_model_fn = create_model_fn
    self._env_vars = kwargs.get('env_vars', {})
    self._load_model_args = {} if not load_model_args else load_model_args
    self._custom_weights = custom_weights
    self._batching_kwargs = {}
    if min_batch_size is not None:
      self._batching_kwargs['min_batch_size'] = min_batch_size
    if max_batch_size is not None:
      self._batching_kwargs['max_batch_size'] = max_batch_size

  def load_model(self) -> tf.Module:
    """Loads and initializes a TensorFlow model for processing."""
    if self._model_type == ModelType.SAVED_WEIGHTS:
      if not self._create_model_fn:
        raise ValueError(
            "Callable create_model_fn must be passed "
            "with ModelType.SAVED_WEIGHTS")
      return _load_model_from_weights(self._create_model_fn, self._model_uri)
    return _load_model(
        self._model_uri, self._custom_weights, self._load_model_args)

  def update_model_path(self, model_path: Optional[str] = None):
    self._model_uri = model_path if model_path else self._model_uri

  def run_inference(
      self,
      batch: Sequence[tf.Tensor],
      model: tf.Module,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """
    Runs inferences on a batch of tf.Tensor and returns an Iterable of
    Tensor predictions.

    This method stacks the list of Tensors in a vectorized format to optimize
    the inference call.

    Args:
      batch: A sequence of Tensors. These Tensors should be batchable, as this
        method will call `tf.stack()` and pass in batched Tensors with
        dimensions (batch_size, n_features, etc.) into the model's predict()
        function.
      model: A TensorFlow model.
      inference_args: Non-batchable arguments required as inputs to the model's
        inference call. Unlike Tensors in `batch`, these parameters will
        not be dynamically batched.

    Returns:
      An Iterable of type PredictionResult.
    """
    inference_args = {} if not inference_args else inference_args
    return self._inference_fn(model, batch, inference_args, self._model_uri)

  def get_num_bytes(self, batch: Sequence[tf.Tensor]) -> int:
    """
    Returns:
      The number of bytes of data for a batch of Tensors.
    """
    return sum(sys.getsizeof(element) for element in batch)

  def get_metrics_namespace(self) -> str:
    """
    Returns:
       A namespace for metrics collected by the RunInference transform.
    """
    return 'BeamML_TF_Tensor'

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    pass

  def batch_elements_kwargs(self):
    return self._batching_kwargs
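
# A minimal end-to-end sketch for the tensor handler (illustrative; the model
# path is hypothetical, and the `training=False` keyword assumes a Keras model
# whose call signature accepts it): elements are tf.Tensor instances, and
# per-call keyword arguments can be supplied through RunInference's
# `inference_args`.
#
#   import apache_beam as beam
#   from apache_beam.ml.inference.base import RunInference
#
#   with beam.Pipeline() as p:
#     _ = (
#         p
#         | beam.Create([tf.constant([1.0, 2.0, 3.0])])
#         | RunInference(
#             TFModelHandlerTensor(model_uri='gs://my-bucket/my_saved_model'),
#             inference_args={'training': False})
#         | beam.Map(lambda result: result.inference))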