github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/tensorflow_inference.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import enum
import sys
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence
from typing import Union

import numpy

import tensorflow as tf
import tensorflow_hub as hub
from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult

__all__ = [
    'TFModelHandlerNumpy',
    'TFModelHandlerTensor',
]

TensorInferenceFn = Callable[[
    tf.Module,
    Sequence[Union[numpy.ndarray, tf.Tensor]],
    Dict[str, Any],
    Optional[str]
],
                             Iterable[PredictionResult]]


class ModelType(enum.Enum):
  """Defines how a model file should be loaded."""
  SAVED_MODEL = 1
  SAVED_WEIGHTS = 2


def _load_model(model_uri, custom_weights, load_model_args):
  model = tf.keras.models.load_model(hub.resolve(model_uri), **load_model_args)
  if custom_weights:
    model.load_weights(custom_weights)
  return model


def _load_model_from_weights(create_model_fn, weights_path):
  model = create_model_fn()
  model.load_weights(weights_path)
  return model


def default_numpy_inference_fn(
    model: tf.Module,
    batch: Sequence[numpy.ndarray],
    inference_args: Dict[str, Any],
    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
  vectorized_batch = numpy.stack(batch, axis=0)
  predictions = model(vectorized_batch, **inference_args)
  return utils._convert_to_result(batch, predictions, model_id)


def default_tensor_inference_fn(
    model: tf.Module,
    batch: Sequence[tf.Tensor],
    inference_args: Dict[str, Any],
    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
  vectorized_batch = tf.stack(batch, axis=0)
  predictions = model(vectorized_batch, **inference_args)
  return utils._convert_to_result(batch, predictions, model_id)
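

# Illustrative sketch, not part of the original module: the handlers below
# accept any callable matching TensorInferenceFn via their `inference_fn`
# argument. As one hypothetical example, a user could route inference through
# Keras' `predict()` instead of the default `__call__`. The function name and
# the assumption that the loaded model is a Keras model exposing `predict()`
# are illustrative only.
def _example_keras_predict_inference_fn(
    model: tf.Module,
    batch: Sequence[numpy.ndarray],
    inference_args: Dict[str, Any],
    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
  # Stack the batch exactly as default_numpy_inference_fn does, then call
  # predict() on the vectorized batch instead of invoking the model directly.
  vectorized_batch = numpy.stack(batch, axis=0)
  predictions = model.predict(vectorized_batch, **inference_args)
  return utils._convert_to_result(batch, predictions, model_id)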


class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
                                       PredictionResult,
                                       tf.Module]):
  def __init__(
      self,
      model_uri: str,
      model_type: ModelType = ModelType.SAVED_MODEL,
      create_model_fn: Optional[Callable] = None,
      *,
      load_model_args: Optional[Dict[str, Any]] = None,
      custom_weights: str = "",
      inference_fn: TensorInferenceFn = default_numpy_inference_fn,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      **kwargs):
    """Implementation of the ModelHandler interface for Tensorflow.

    Example Usage::

      pcoll | RunInference(TFModelHandlerNumpy(model_uri="my_uri"))

    See https://www.tensorflow.org/tutorials/keras/save_and_load for details.

    Args:
      model_uri (str): path to the trained model.
      model_type: type of model to be loaded. Defaults to SAVED_MODEL.
      create_model_fn: a function that creates and returns a new
        tensorflow model to load the saved weights.
        It should be used with ModelType.SAVED_WEIGHTS.
      load_model_args: a dictionary of parameters to pass to the load_model
        function of TensorFlow to specify custom config.
      custom_weights (str): path to the custom weights to be applied
        once the model is loaded.
      inference_fn: inference function to use during RunInference.
        Defaults to default_numpy_inference_fn.
      min_batch_size: the minimum batch size to use when batching inputs.
      max_batch_size: the maximum batch size to use when batching inputs.
      kwargs: 'env_vars' can be used to set environment variables
        before loading the model.

    **Supported Versions:** RunInference APIs in Apache Beam have been tested
    with Tensorflow 2.9, 2.10, 2.11.
    """
    self._model_uri = model_uri
    self._model_type = model_type
    self._inference_fn = inference_fn
    self._create_model_fn = create_model_fn
    self._env_vars = kwargs.get('env_vars', {})
    self._load_model_args = {} if not load_model_args else load_model_args
    self._custom_weights = custom_weights
    self._batching_kwargs = {}
    if min_batch_size is not None:
      self._batching_kwargs['min_batch_size'] = min_batch_size
    if max_batch_size is not None:
      self._batching_kwargs['max_batch_size'] = max_batch_size

  def load_model(self) -> tf.Module:
    """Loads and initializes a Tensorflow model for processing."""
    if self._model_type == ModelType.SAVED_WEIGHTS:
      if not self._create_model_fn:
        raise ValueError(
            "Callable create_model_fn must be passed "
            "with ModelType.SAVED_WEIGHTS")
      return _load_model_from_weights(self._create_model_fn, self._model_uri)
    return _load_model(
        self._model_uri, self._custom_weights, self._load_model_args)

  def update_model_path(self, model_path: Optional[str] = None):
    self._model_uri = model_path if model_path else self._model_uri

  def run_inference(
      self,
      batch: Sequence[numpy.ndarray],
      model: tf.Module,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """
    Runs inferences on a batch of numpy arrays and returns an Iterable of
    numpy array Predictions.

    This method stacks the n-dimensional numpy arrays in a vectorized format
    to optimize the inference call.

    Args:
      batch: A sequence of numpy nd-arrays. These should be batchable, as this
        method will call `numpy.stack()` and pass in batched numpy nd-arrays
        with dimensions (batch_size, n_features, etc.) into the model's
        predict() function.
      model: A Tensorflow model.
      inference_args: any additional arguments for an inference.

    Returns:
      An Iterable of type PredictionResult.
    """
    inference_args = {} if not inference_args else inference_args
    return self._inference_fn(model, batch, inference_args, self._model_uri)

  def get_num_bytes(self, batch: Sequence[numpy.ndarray]) -> int:
    """
    Returns:
      The number of bytes of data for a batch of numpy arrays.
    """
    return sum(sys.getsizeof(element) for element in batch)

  def get_metrics_namespace(self) -> str:
    """
    Returns:
      A namespace for metrics collected by the RunInference transform.
    """
    return 'BeamML_TF_Numpy'

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    pass

  def batch_elements_kwargs(self):
    return self._batching_kwargs
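

# Illustrative sketch, not part of the original module: with
# ModelType.SAVED_WEIGHTS, load_model() requires a `create_model_fn` that
# rebuilds the architecture, and `model_uri` must point at the saved weights.
# The model layout and the weights path below are hypothetical, used only to
# show how the pieces fit together.
def _example_saved_weights_handler() -> TFModelHandlerNumpy:
  def create_model():
    # Rebuild the same architecture the weights were originally saved from.
    return tf.keras.Sequential([
        tf.keras.layers.Dense(10, activation='relu', input_shape=(5, )),
        tf.keras.layers.Dense(1),
    ])

  return TFModelHandlerNumpy(
      model_uri='gs://my-bucket/model_weights',  # hypothetical weights path
      model_type=ModelType.SAVED_WEIGHTS,
      create_model_fn=create_model)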


class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
                                        tf.Module]):
  def __init__(
      self,
      model_uri: str,
      model_type: ModelType = ModelType.SAVED_MODEL,
      create_model_fn: Optional[Callable] = None,
      *,
      load_model_args: Optional[Dict[str, Any]] = None,
      custom_weights: str = "",
      inference_fn: TensorInferenceFn = default_tensor_inference_fn,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      **kwargs):
    """Implementation of the ModelHandler interface for Tensorflow.

    Example Usage::

      pcoll | RunInference(TFModelHandlerTensor(model_uri="my_uri"))

    See https://www.tensorflow.org/tutorials/keras/save_and_load for details.

    Args:
      model_uri (str): path to the trained model.
      model_type: type of model to be loaded. Defaults to SAVED_MODEL.
      create_model_fn: a function that creates and returns a new
        tensorflow model to load the saved weights.
        It should be used with ModelType.SAVED_WEIGHTS.
      load_model_args: a dictionary of parameters to pass to the load_model
        function of TensorFlow to specify custom config.
      custom_weights (str): path to the custom weights to be applied
        once the model is loaded.
      inference_fn: inference function to use during RunInference.
        Defaults to default_tensor_inference_fn.
      min_batch_size: the minimum batch size to use when batching inputs.
      max_batch_size: the maximum batch size to use when batching inputs.
      kwargs: 'env_vars' can be used to set environment variables
        before loading the model.

    **Supported Versions:** RunInference APIs in Apache Beam have been tested
    with Tensorflow 2.11.
    """
    self._model_uri = model_uri
    self._model_type = model_type
    self._inference_fn = inference_fn
    self._create_model_fn = create_model_fn
    self._env_vars = kwargs.get('env_vars', {})
    self._load_model_args = {} if not load_model_args else load_model_args
    self._custom_weights = custom_weights
    self._batching_kwargs = {}
    if min_batch_size is not None:
      self._batching_kwargs['min_batch_size'] = min_batch_size
    if max_batch_size is not None:
      self._batching_kwargs['max_batch_size'] = max_batch_size

  def load_model(self) -> tf.Module:
    """Loads and initializes a tensorflow model for processing."""
    if self._model_type == ModelType.SAVED_WEIGHTS:
      if not self._create_model_fn:
        raise ValueError(
            "Callable create_model_fn must be passed "
            "with ModelType.SAVED_WEIGHTS")
      return _load_model_from_weights(self._create_model_fn, self._model_uri)
    return _load_model(
        self._model_uri, self._custom_weights, self._load_model_args)

  def update_model_path(self, model_path: Optional[str] = None):
    self._model_uri = model_path if model_path else self._model_uri

  def run_inference(
      self,
      batch: Sequence[tf.Tensor],
      model: tf.Module,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """
    Runs inferences on a batch of tf.Tensor and returns an Iterable of
    Tensor Predictions.

    This method stacks the list of Tensors in a vectorized format to optimize
    the inference call.

    Args:
      batch: A sequence of Tensors. These Tensors should be batchable, as this
        method will call `tf.stack()` and pass in batched Tensors with
        dimensions (batch_size, n_features, etc.) into the model's predict()
        function.
      model: A Tensorflow model.
      inference_args: Non-batchable arguments required as inputs to the
        model's forward() function. Unlike Tensors in `batch`, these
        parameters will not be dynamically batched.

    Returns:
      An Iterable of type PredictionResult.
    """
    inference_args = {} if not inference_args else inference_args
    return self._inference_fn(model, batch, inference_args, self._model_uri)

  def get_num_bytes(self, batch: Sequence[tf.Tensor]) -> int:
    """
    Returns:
      The number of bytes of data for a batch of Tensors.
    """
    return sum(sys.getsizeof(element) for element in batch)

  def get_metrics_namespace(self) -> str:
    """
    Returns:
      A namespace for metrics collected by the RunInference transform.
    """
    return 'BeamML_TF_Tensor'

  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
    pass

  def batch_elements_kwargs(self):
    return self._batching_kwargs
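

# Illustrative sketch, not part of the original module: a minimal pipeline
# wiring one of the handlers above into the RunInference transform referenced
# in the docstrings. RunInference is imported from
# apache_beam.ml.inference.base; the saved-model path and the example input
# element are hypothetical.
def _example_pipeline():
  import apache_beam as beam
  from apache_beam.ml.inference.base import RunInference

  with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create([tf.constant([1.0, 2.0, 3.0])])
        | RunInference(
            TFModelHandlerTensor(model_uri='gs://my-bucket/saved_model'))
        | beam.Map(print))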