github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/sklearn_inference.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import enum
import pickle
import sys
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence

import numpy
import pandas
from sklearn.base import BaseEstimator

from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult

try:
  import joblib
except ImportError:
  # joblib is an optional dependency. Bind the name to None so the
  # availability check in _load_model raises a helpful ImportError
  # instead of a NameError.
  joblib = None

__all__ = [
    'SklearnModelHandlerNumpy',
    'SklearnModelHandlerPandas',
]

NumpyInferenceFn = Callable[
    [BaseEstimator, Sequence[numpy.ndarray], Optional[Dict[str, Any]]], Any]


class ModelFileType(enum.Enum):
  """Defines how a model file is serialized. Options are pickle or joblib."""
  PICKLE = 1
  JOBLIB = 2


def _load_model(model_uri, file_type):
  file = FileSystems.open(model_uri, 'rb')
  if file_type == ModelFileType.PICKLE:
    return pickle.load(file)
  elif file_type == ModelFileType.JOBLIB:
    if not joblib:
      raise ImportError(
          'Could not import joblib in this execution environment. '
          'For help with managing dependencies on Python workers, see '
          'https://beam.apache.org/documentation/sdks/python-pipeline-dependencies/'  # pylint: disable=line-too-long
      )
    return joblib.load(file)
  raise AssertionError('Unsupported serialization type.')


def _default_numpy_inference_fn(
    model: BaseEstimator,
    batch: Sequence[numpy.ndarray],
    inference_args: Optional[Dict[str, Any]] = None) -> Any:
  # Stack the single examples into one array so predict runs vectorized,
  # which performs better than predicting example by example.
  vectorized_batch = numpy.stack(batch, axis=0)
  return model.predict(vectorized_batch)
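

# Illustrative sketch, not part of the original module: a custom
# NumpyInferenceFn has the same signature as _default_numpy_inference_fn
# above and can be passed to SklearnModelHandlerNumpy via the inference_fn
# keyword. This hypothetical variant assumes the loaded estimator is a
# classifier exposing predict_proba, and returns class probabilities
# instead of predicted labels.
def _example_predict_proba_inference_fn(
    model: BaseEstimator,
    batch: Sequence[numpy.ndarray],
    inference_args: Optional[Dict[str, Any]] = None) -> Any:
  # Vectorize exactly as the default inference fn does.
  vectorized_batch = numpy.stack(batch, axis=0)
  # Assumes a classifier (e.g. LogisticRegression); regressors do not
  # expose predict_proba.
  return model.predict_proba(vectorized_batch)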


class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
                                            PredictionResult,
                                            BaseEstimator]):
  def __init__(
      self,
      model_uri: str,
      model_file_type: ModelFileType = ModelFileType.PICKLE,
      *,
      inference_fn: NumpyInferenceFn = _default_numpy_inference_fn,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      **kwargs):
    """Implementation of the ModelHandler interface for scikit-learn
    using numpy arrays as input.

    Example Usage::

      pcoll | RunInference(SklearnModelHandlerNumpy(model_uri="my_uri"))

    Args:
      model_uri: The URI where the model is saved.
      model_file_type: The serialization format of the saved model.
        default=pickle
      inference_fn: The inference function to use.
        default=_default_numpy_inference_fn
      min_batch_size: the minimum batch size to use when batching inputs. This
        batch will be fed into the inference_fn as a Sequence of numpy
        ndarrays.
      max_batch_size: the maximum batch size to use when batching inputs. This
        batch will be fed into the inference_fn as a Sequence of numpy
        ndarrays.
      kwargs: 'env_vars' can be used to set environment variables
        before loading the model.
    """
    self._model_uri = model_uri
    self._model_file_type = model_file_type
    self._model_inference_fn = inference_fn
    self._batching_kwargs = {}
    if min_batch_size is not None:
      self._batching_kwargs['min_batch_size'] = min_batch_size
    if max_batch_size is not None:
      self._batching_kwargs['max_batch_size'] = max_batch_size
    self._env_vars = kwargs.get('env_vars', {})

  def load_model(self) -> BaseEstimator:
    """Loads and initializes a model for processing."""
    return _load_model(self._model_uri, self._model_file_type)

  def update_model_path(self, model_path: Optional[str] = None):
    self._model_uri = model_path if model_path else self._model_uri

  def run_inference(
      self,
      batch: Sequence[numpy.ndarray],
      model: BaseEstimator,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """Runs inferences on a batch of numpy arrays.

    Args:
      batch: A sequence of examples as numpy arrays. They should
        be single examples.
      model: A scikit-learn model or pipeline. Must implement predict(X),
        where the parameter X is a numpy array.
      inference_args: Any additional arguments for an inference.

    Returns:
      An Iterable of type PredictionResult.
    """
    predictions = self._model_inference_fn(
        model,
        batch,
        inference_args,
    )

    return utils._convert_to_result(
        batch, predictions, model_id=self._model_uri)

  def get_num_bytes(self, batch: Sequence[numpy.ndarray]) -> int:
    """
    Returns:
      The number of bytes of data for a batch.
    """
    return sum(sys.getsizeof(element) for element in batch)

  def get_metrics_namespace(self) -> str:
    """
    Returns:
      A namespace for metrics collected by the RunInference transform.
    """
    return 'BeamML_Sklearn'

  def batch_elements_kwargs(self):
    return self._batching_kwargs


PandasInferenceFn = Callable[
    [BaseEstimator, Sequence[pandas.DataFrame], Optional[Dict[str, Any]]], Any]


def _default_pandas_inference_fn(
    model: BaseEstimator,
    batch: Sequence[pandas.DataFrame],
    inference_args: Optional[Dict[str, Any]] = None) -> Any:
  # Concatenate the single-row dataframes so predict runs vectorized, then
  # split the concatenated input back into one-row dataframes so each
  # prediction can be paired with the example that produced it.
  vectorized_batch = pandas.concat(batch, axis=0)
  predictions = model.predict(vectorized_batch)
  splits = [
      vectorized_batch.iloc[[i]] for i in range(vectorized_batch.shape[0])
  ]
  return predictions, splits
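

# Illustrative sketch, not part of the original module: a custom
# PandasInferenceFn must preserve the (predictions, splits) return contract
# of _default_pandas_inference_fn above so that run_inference can pair each
# prediction with its single-row input dataframe. This hypothetical variant
# assumes a classifier exposing predict_proba.
def _example_pandas_predict_proba_fn(
    model: BaseEstimator,
    batch: Sequence[pandas.DataFrame],
    inference_args: Optional[Dict[str, Any]] = None) -> Any:
  vectorized_batch = pandas.concat(batch, axis=0)
  # Assumes a classifier; yields one row of class probabilities per input.
  predictions = model.predict_proba(vectorized_batch)
  # Keep the contract: one single-row dataframe per prediction.
  splits = [
      vectorized_batch.iloc[[i]] for i in range(vectorized_batch.shape[0])
  ]
  return predictions, splits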


class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
                                             PredictionResult,
                                             BaseEstimator]):
  def __init__(
      self,
      model_uri: str,
      model_file_type: ModelFileType = ModelFileType.PICKLE,
      *,
      inference_fn: PandasInferenceFn = _default_pandas_inference_fn,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      **kwargs):
    """Implementation of the ModelHandler interface for scikit-learn that
    supports pandas dataframes.

    Example Usage::

      pcoll | RunInference(SklearnModelHandlerPandas(model_uri="my_uri"))

    **NOTE:** This API and its implementation are under development and
    do not provide backward compatibility guarantees.

    Args:
      model_uri: The URI where the model is saved.
      model_file_type: The serialization format of the saved model.
        default=pickle
      inference_fn: The inference function to use.
        default=_default_pandas_inference_fn
      min_batch_size: the minimum batch size to use when batching inputs. This
        batch will be fed into the inference_fn as a Sequence of pandas
        DataFrames.
      max_batch_size: the maximum batch size to use when batching inputs. This
        batch will be fed into the inference_fn as a Sequence of pandas
        DataFrames.
      kwargs: 'env_vars' can be used to set environment variables
        before loading the model.
    """
    self._model_uri = model_uri
    self._model_file_type = model_file_type
    self._model_inference_fn = inference_fn
    self._batching_kwargs = {}
    if min_batch_size is not None:
      self._batching_kwargs['min_batch_size'] = min_batch_size
    if max_batch_size is not None:
      self._batching_kwargs['max_batch_size'] = max_batch_size
    self._env_vars = kwargs.get('env_vars', {})

  def load_model(self) -> BaseEstimator:
    """Loads and initializes a model for processing."""
    return _load_model(self._model_uri, self._model_file_type)

  def update_model_path(self, model_path: Optional[str] = None):
    self._model_uri = model_path if model_path else self._model_uri

  def run_inference(
      self,
      batch: Sequence[pandas.DataFrame],
      model: BaseEstimator,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """Runs inferences on a batch of pandas dataframes.

    Args:
      batch: A sequence of examples as pandas dataframes. Each dataframe
        should hold a single example (one row).
      model: A scikit-learn model or pipeline. Must implement predict(X),
        where the parameter X is a pandas dataframe.
      inference_args: Any additional arguments for an inference.

    Returns:
      An Iterable of type PredictionResult.
    """
    # sklearn_inference currently only supports single-row dataframes.
    for dataframe in iter(batch):
      if dataframe.shape[0] != 1:
        raise ValueError('Only dataframes with single rows are supported.')

    predictions, splits = self._model_inference_fn(model, batch, inference_args)

    return utils._convert_to_result(
        splits, predictions, model_id=self._model_uri)

  def get_num_bytes(self, batch: Sequence[pandas.DataFrame]) -> int:
    """
    Returns:
      The number of bytes of data for a batch.
    """
    return sum(df.memory_usage(deep=True).sum() for df in batch)

  def get_metrics_namespace(self) -> str:
    """
    Returns:
      A namespace for metrics collected by the RunInference transform.
    """
    return 'BeamML_Sklearn'

  def batch_elements_kwargs(self):
    return self._batching_kwargs
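

# Illustrative end-to-end usage sketch, not part of the original module.
# It assumes a scikit-learn model pickled at the hypothetical path
# 'gs://my-bucket/model.pkl' and trained on two-feature inputs:
#
#   import apache_beam as beam
#   from apache_beam.ml.inference.base import RunInference
#
#   with beam.Pipeline() as p:
#     _ = (
#         p
#         | beam.Create([numpy.array([1.0, 2.0]), numpy.array([3.0, 4.0])])
#         | RunInference(
#             SklearnModelHandlerNumpy(model_uri='gs://my-bucket/model.pkl'))
#         | beam.Map(print))  # each element is a PredictionResult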