github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/sklearn_inference.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import enum
import pickle
import sys
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence

import numpy
import pandas
from sklearn.base import BaseEstimator

from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
try:
  import joblib
except ImportError:
  # joblib is an optional dependency; define it as None so that _load_model
  # can raise a descriptive ImportError instead of a NameError when a
  # joblib-serialized model is requested without joblib installed.
  joblib = None

__all__ = [
    'SklearnModelHandlerNumpy',
    'SklearnModelHandlerPandas',
]

NumpyInferenceFn = Callable[
    [BaseEstimator, Sequence[numpy.ndarray], Optional[Dict[str, Any]]], Any]


class ModelFileType(enum.Enum):
  """Defines how a model file is serialized. Options are pickle or joblib."""
  PICKLE = 1
  JOBLIB = 2

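# A model artifact in either format can be produced with standard pickle /
# joblib calls. A minimal sketch (the estimator, training data, and local
# paths below are illustrative, not part of this module):
#
#   from sklearn.linear_model import LinearRegression
#   X = numpy.array([[0.0], [1.0], [2.0]])
#   y = numpy.array([0.0, 1.0, 2.0])
#   model = LinearRegression().fit(X, y)
#
#   with open('model.pkl', 'wb') as f:
#     pickle.dump(model, f)               # load with ModelFileType.PICKLE
#   joblib.dump(model, 'model.joblib')    # load with ModelFileType.JOBLIB
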

def _load_model(model_uri, file_type):
  with FileSystems.open(model_uri, 'rb') as file:
    if file_type == ModelFileType.PICKLE:
      return pickle.load(file)
    elif file_type == ModelFileType.JOBLIB:
      if not joblib:
        raise ImportError(
            'Could not import joblib in this execution environment. '
            'For help with managing dependencies on Python workers, '
            'see https://beam.apache.org/documentation/sdks/python-pipeline-dependencies/'  # pylint: disable=line-too-long
        )
      return joblib.load(file)
  raise AssertionError('Unsupported serialization type.')


def _default_numpy_inference_fn(
    model: BaseEstimator,
    batch: Sequence[numpy.ndarray],
    inference_args: Optional[Dict[str, Any]] = None) -> Any:
  # vectorize data for better performance
  vectorized_batch = numpy.stack(batch, axis=0)
  return model.predict(vectorized_batch)

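# The default above calls predict(); a custom NumpyInferenceFn with the same
# signature can swap in another estimator method. A sketch using
# predict_proba (only valid for classifiers that implement it; the function
# name is illustrative):
#
#   def predict_proba_inference_fn(
#       model: BaseEstimator,
#       batch: Sequence[numpy.ndarray],
#       inference_args: Optional[Dict[str, Any]] = None) -> Any:
#     vectorized_batch = numpy.stack(batch, axis=0)
#     return model.predict_proba(vectorized_batch)
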

class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
                                            PredictionResult,
                                            BaseEstimator]):
  def __init__(
      self,
      model_uri: str,
      model_file_type: ModelFileType = ModelFileType.PICKLE,
      *,
      inference_fn: NumpyInferenceFn = _default_numpy_inference_fn,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      **kwargs):
    """Implementation of the ModelHandler interface for scikit-learn
    using numpy arrays as input.

    Example Usage::

      pcoll | RunInference(SklearnModelHandlerNumpy(model_uri="my_uri"))

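    A fuller pipeline sketch, assuming a model pickled at the illustrative
    path "gs://my-bucket/model.pkl" that was trained on two-feature rows,
    with RunInference imported from apache_beam.ml.inference.base::

      with beam.Pipeline() as p:
        _ = (
            p
            | beam.Create([numpy.array([1.0, 2.0]), numpy.array([3.0, 4.0])])
            | RunInference(
                SklearnModelHandlerNumpy(
                    model_uri="gs://my-bucket/model.pkl",
                    min_batch_size=2,
                    max_batch_size=100))
            | beam.Map(print))
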
    Args:
      model_uri: The URI to where the model is saved.
      model_file_type: The serialization format of the saved model.
        Defaults to ModelFileType.PICKLE.
      inference_fn: The inference function to use.
        Defaults to _default_numpy_inference_fn.
      min_batch_size: The minimum batch size to use when batching inputs. This
        batch will be fed into the inference_fn as a Sequence of Numpy
        ndarrays.
      max_batch_size: The maximum batch size to use when batching inputs. This
        batch will be fed into the inference_fn as a Sequence of Numpy
        ndarrays.
      kwargs: 'env_vars' can be used to set environment variables
        before loading the model.
    """
    self._model_uri = model_uri
    self._model_file_type = model_file_type
    self._model_inference_fn = inference_fn
    self._batching_kwargs = {}
    if min_batch_size is not None:
      self._batching_kwargs['min_batch_size'] = min_batch_size
    if max_batch_size is not None:
      self._batching_kwargs['max_batch_size'] = max_batch_size
    self._env_vars = kwargs.get('env_vars', {})

  def load_model(self) -> BaseEstimator:
    """Loads and initializes a model for processing."""
    return _load_model(self._model_uri, self._model_file_type)

  def update_model_path(self, model_path: Optional[str] = None):
    self._model_uri = model_path if model_path else self._model_uri

  def run_inference(
      self,
      batch: Sequence[numpy.ndarray],
      model: BaseEstimator,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """Runs inferences on a batch of numpy arrays.

    Args:
      batch: A sequence of examples as numpy arrays. They should
        be single examples.
      model: A scikit-learn model or pipeline. Must implement predict(X),
        where X is a numpy array.
      inference_args: Any additional arguments for an inference.

    Returns:
      An Iterable of type PredictionResult.
    """
    predictions = self._model_inference_fn(
        model,
        batch,
        inference_args,
    )

    return utils._convert_to_result(
        batch, predictions, model_id=self._model_uri)

  def get_num_bytes(self, batch: Sequence[numpy.ndarray]) -> int:
    """
    Returns:
      The number of bytes of data for a batch.
    """
    return sum(sys.getsizeof(element) for element in batch)

  def get_metrics_namespace(self) -> str:
    """
    Returns:
      A namespace for metrics collected by the RunInference transform.
    """
    return 'BeamML_Sklearn'

  def batch_elements_kwargs(self):
    return self._batching_kwargs


PandasInferenceFn = Callable[
    [BaseEstimator, Sequence[pandas.DataFrame], Optional[Dict[str, Any]]], Any]


def _default_pandas_inference_fn(
    model: BaseEstimator,
    batch: Sequence[pandas.DataFrame],
    inference_args: Optional[Dict[str, Any]] = None) -> Any:
  # vectorize data for better performance
  vectorized_batch = pandas.concat(batch, axis=0)
  predictions = model.predict(vectorized_batch)
  # split the concatenated frame back into single-row frames so each input
  # row can be paired with its prediction
  splits = [
      vectorized_batch.iloc[[i]] for i in range(vectorized_batch.shape[0])
  ]
  return predictions, splits

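# A custom PandasInferenceFn can adapt the frame before prediction. A minimal
# sketch that excludes a hypothetical non-feature 'key' column from predict()
# while keeping it in the per-row splits returned to the caller:
#
#   def keyed_pandas_inference_fn(model, batch, inference_args=None):
#     vectorized_batch = pandas.concat(batch, axis=0)
#     predictions = model.predict(vectorized_batch.drop(columns=['key']))
#     splits = [
#         vectorized_batch.iloc[[i]] for i in range(vectorized_batch.shape[0])
#     ]
#     return predictions, splits
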

class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
                                             PredictionResult,
                                             BaseEstimator]):
  def __init__(
      self,
      model_uri: str,
      model_file_type: ModelFileType = ModelFileType.PICKLE,
      *,
      inference_fn: PandasInferenceFn = _default_pandas_inference_fn,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      **kwargs):
    """Implementation of the ModelHandler interface for scikit-learn that
    supports pandas dataframes.

    Example Usage::

      pcoll | RunInference(SklearnModelHandlerPandas(model_uri="my_uri"))

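    Each element of the input PCollection must be a dataframe with a single
    row, so a larger dataframe is typically split row-by-row upstream. A
    minimal sketch (the dataframe ``df`` is illustrative)::

      splits = [df.iloc[[i]] for i in range(df.shape[0])]
      pcoll = p | beam.Create(splits)
      pcoll | RunInference(SklearnModelHandlerPandas(model_uri="my_uri"))
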
    **NOTE:** This API and its implementation are under development and
    do not provide backward compatibility guarantees.

    Args:
      model_uri: The URI to where the model is saved.
      model_file_type: The serialization format of the saved model.
        Defaults to ModelFileType.PICKLE.
      inference_fn: The inference function to use.
        Defaults to _default_pandas_inference_fn.
      min_batch_size: The minimum batch size to use when batching inputs. This
        batch will be fed into the inference_fn as a Sequence of Pandas
        Dataframes.
      max_batch_size: The maximum batch size to use when batching inputs. This
        batch will be fed into the inference_fn as a Sequence of Pandas
        Dataframes.
      kwargs: 'env_vars' can be used to set environment variables
        before loading the model.
    """
    self._model_uri = model_uri
    self._model_file_type = model_file_type
    self._model_inference_fn = inference_fn
    self._batching_kwargs = {}
    if min_batch_size is not None:
      self._batching_kwargs['min_batch_size'] = min_batch_size
    if max_batch_size is not None:
      self._batching_kwargs['max_batch_size'] = max_batch_size
    self._env_vars = kwargs.get('env_vars', {})

  def load_model(self) -> BaseEstimator:
    """Loads and initializes a model for processing."""
    return _load_model(self._model_uri, self._model_file_type)

  def update_model_path(self, model_path: Optional[str] = None):
    self._model_uri = model_path if model_path else self._model_uri

  def run_inference(
      self,
      batch: Sequence[pandas.DataFrame],
      model: BaseEstimator,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """
    Runs inferences on a batch of pandas dataframes.

    Args:
      batch: A sequence of examples as pandas dataframes. Each should
        be a single-row dataframe holding one example.
      model: A scikit-learn model or pipeline. Must implement predict(X),
        where X is a pandas dataframe.
      inference_args: Any additional arguments for an inference.

    Returns:
      An Iterable of type PredictionResult.
    """
    # sklearn_inference currently only supports single-row dataframes.
    for dataframe in batch:
      if dataframe.shape[0] != 1:
        raise ValueError('Only dataframes with single rows are supported.')

    predictions, splits = self._model_inference_fn(model, batch, inference_args)

    return utils._convert_to_result(
        splits, predictions, model_id=self._model_uri)

  def get_num_bytes(self, batch: Sequence[pandas.DataFrame]) -> int:
    """
    Returns:
      The number of bytes of data for a batch.
    """
    # memory_usage sums to a numpy integer; cast so the declared int
    # return type holds.
    return int(sum(df.memory_usage(deep=True).sum() for df in batch))

  def get_metrics_namespace(self) -> str:
    """
    Returns:
      A namespace for metrics collected by the RunInference transform.
    """
    return 'BeamML_Sklearn'

  def batch_elements_kwargs(self):
    return self._batching_kwargs