github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/xgboost_inference.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  import sys
    19  from abc import ABC
    20  from typing import Any
    21  from typing import Callable
    22  from typing import Dict
    23  from typing import Iterable
    24  from typing import Optional
    25  from typing import Sequence
    26  from typing import Union
    27  
    28  import numpy
    29  import pandas
    30  import scipy
    31  
    32  import datatable
    33  import xgboost
    34  from apache_beam.io.filesystems import FileSystems
    35  from apache_beam.ml.inference.base import ExampleT
    36  from apache_beam.ml.inference.base import ModelHandler
    37  from apache_beam.ml.inference.base import ModelT
    38  from apache_beam.ml.inference.base import PredictionResult
    39  from apache_beam.ml.inference.base import PredictionT
    40  
# Public names of this module: the abstract base handler plus one concrete
# handler per supported input container type.
__all__ = [
    'XGBoostModelHandler',
    'XGBoostModelHandlerNumpy',
    'XGBoostModelHandlerPandas',
    'XGBoostModelHandlerSciPy',
    'XGBoostModelHandlerDatatable'
]
    48  
# Type of the inference functions accepted by the XGBoost model handlers.
# The handlers call it as ``fn(batch, model, inference_args)`` and expect an
# iterable of PredictionResult back.
XGBoostInferenceFn = Callable[[
    Sequence[object],
    Union[xgboost.Booster, xgboost.XGBModel],
    Optional[Dict[str, Any]]
],
                              Iterable[PredictionResult]]
    55  
    56  
    57  def default_xgboost_inference_fn(
    58      batch: Sequence[object],
    59      model: Union[xgboost.Booster, xgboost.XGBModel],
    60      inference_args: Optional[Dict[str,
    61                                    Any]] = None) -> Iterable[PredictionResult]:
    62    inference_args = {} if not inference_args else inference_args
    63  
    64    if type(model) == xgboost.Booster:
    65      batch = [xgboost.DMatrix(array) for array in batch]
    66    predictions = [model.predict(el, **inference_args) for el in batch]
    67  
    68    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
    69  
    70  
    71  class XGBoostModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
    72    def __init__(
    73        self,
    74        model_class: Union[Callable[..., xgboost.Booster],
    75                           Callable[..., xgboost.XGBModel]],
    76        model_state: str,
    77        inference_fn: XGBoostInferenceFn = default_xgboost_inference_fn,
    78        **kwargs):
    79      """Implementation of the ModelHandler interface for XGBoost.
    80  
    81      Example Usage::
    82  
    83          pcoll | RunInference(
    84                      XGBoostModelHandler(
    85                          model_class="XGBoost Model Class",
    86                          model_state="my_model_state.json")))
    87  
    88      See https://xgboost.readthedocs.io/en/stable/tutorials/saving_model.html
    89      for details
    90  
    91      Args:
    92        model_class: class of the XGBoost model that defines the model
    93          structure.
    94        model_state: path to a json file that contains the model's
    95          configuration.
    96        inference_fn: the inference function to use during RunInference.
    97          default=default_xgboost_inference_fn
    98        kwargs: 'env_vars' can be used to set environment variables
    99          before loading the model.
   100  
   101      **Supported Versions:** RunInference APIs in Apache Beam have been tested
   102      with XGBoost 1.6.0 and 1.7.0
   103  
   104      XGBoost 1.0.0 introduced support for using JSON to save and load
   105      XGBoost models. XGBoost 1.6.0, additional support for Universal Binary JSON.
   106      It is recommended to use a model trained in XGBoost 1.6.0 or higher.
   107      While you should be able to load models created in older versions, there
   108      are no guarantees this will work as expected.
   109  
   110      This class is the superclass of all the various XGBoostModelhandlers
   111      and should not be instantiated directly. (See instead
   112      XGBoostModelHandlerNumpy, XGBoostModelHandlerPandas, etc.)
   113      """
   114      self._model_class = model_class
   115      self._model_state = model_state
   116      self._inference_fn = inference_fn
   117      self._env_vars = kwargs.get('env_vars', {})
   118  
   119    def load_model(self) -> Union[xgboost.Booster, xgboost.XGBModel]:
   120      model = self._model_class()
   121      model_state_file_handler = FileSystems.open(self._model_state, 'rb')
   122      model_state_bytes = model_state_file_handler.read()
   123      # Convert into a bytearray so that the
   124      # model state can be loaded in XGBoost
   125      model_state_bytearray = bytearray(model_state_bytes)
   126      model.load_model(model_state_bytearray)
   127      return model
   128  
   129    def get_metrics_namespace(self) -> str:
   130      return 'BeamML_XGBoost'
   131  
   132  
   133  class XGBoostModelHandlerNumpy(XGBoostModelHandler[numpy.ndarray,
   134                                                     PredictionResult,
   135                                                     Union[xgboost.Booster,
   136                                                           xgboost.XGBModel]]):
   137    """Implementation of the ModelHandler interface for XGBoost
   138    using numpy arrays as input.
   139  
   140    Example Usage::
   141  
   142        pcoll | RunInference(
   143                    XGBoostModelHandlerNumpy(
   144                        model_class="XGBoost Model Class",
   145                        model_state="my_model_state.json")))
   146  
   147    Args:
   148      model_class: class of the XGBoost model that defines the model
   149        structure.
   150      model_state: path to a json file that contains the model's
   151        configuration.
   152      inference_fn: the inference function to use during RunInference.
   153        default=default_xgboost_inference_fn
   154    """
   155    def run_inference(
   156        self,
   157        batch: Sequence[numpy.ndarray],
   158        model: Union[xgboost.Booster, xgboost.XGBModel],
   159        inference_args: Optional[Dict[str, Any]] = None
   160    ) -> Iterable[PredictionResult]:
   161      """Runs inferences on a batch of 2d numpy arrays.
   162  
   163      Args:
   164        batch: A sequence of examples as 2d numpy arrays. Each
   165          row in an array is a single example. The dimensions
   166          must match the dimensions of the data used to train
   167          the model.
   168        model: XGBoost booster or XBGModel (sklearn interface). Must
   169          implement predict(X). Where the parameter X is a 2d numpy array.
   170        inference_args: Any additional arguments for an inference.
   171  
   172      Returns:
   173        An Iterable of type PredictionResult.
   174      """
   175      return self._inference_fn(batch, model, inference_args)
   176  
   177    def get_num_bytes(self, batch: Sequence[numpy.ndarray]) -> int:
   178      """
   179      Returns:
   180        The number of bytes of data for a batch.
   181      """
   182      return sum(sys.getsizeof(element) for element in batch)
   183  
   184  
   185  class XGBoostModelHandlerPandas(XGBoostModelHandler[pandas.DataFrame,
   186                                                      PredictionResult,
   187                                                      Union[xgboost.Booster,
   188                                                            xgboost.XGBModel]]):
   189    """Implementation of the ModelHandler interface for XGBoost
   190    using pandas dataframes as input.
   191  
   192    Example Usage::
   193  
   194        pcoll | RunInference(
   195                    XGBoostModelHandlerPandas(
   196                        model_class="XGBoost Model Class",
   197                        model_state="my_model_state.json")))
   198  
   199    Args:
   200      model_class: class of the XGBoost model that defines the model
   201        structure.
   202      model_state: path to a json file that contains the model's
   203        configuration.
   204      inference_fn: the inference function to use during RunInference.
   205        default=default_xgboost_inference_fn
   206    """
   207    def run_inference(
   208        self,
   209        batch: Sequence[pandas.DataFrame],
   210        model: Union[xgboost.Booster, xgboost.XGBModel],
   211        inference_args: Optional[Dict[str, Any]] = None
   212    ) -> Iterable[PredictionResult]:
   213      """Runs inferences on a batch of pandas dataframes.
   214  
   215      Args:
   216        batch: A sequence of examples as pandas dataframes. Each
   217          row in a dataframe is a single example. The dimensions
   218          must match the dimensions of the data used to train
   219          the model.
   220        model: XGBoost booster or XBGModel (sklearn interface). Must
   221          implement predict(X). Where the parameter X is a pandas dataframe.
   222        inference_args: Any additional arguments for an inference.
   223  
   224      Returns:
   225        An Iterable of type PredictionResult.
   226      """
   227      return self._inference_fn(batch, model, inference_args)
   228  
   229    def get_num_bytes(self, batch: Sequence[pandas.DataFrame]) -> int:
   230      """
   231      Returns:
   232          The number of bytes of data for a batch of Numpy arrays.
   233      """
   234      return sum(df.memory_usage(deep=True).sum() for df in batch)
   235  
   236  
   237  class XGBoostModelHandlerSciPy(XGBoostModelHandler[scipy.sparse.csr_matrix,
   238                                                     PredictionResult,
   239                                                     Union[xgboost.Booster,
   240                                                           xgboost.XGBModel]]):
   241    """ Implementation of the ModelHandler interface for XGBoost
   242    using scipy matrices as input.
   243  
   244    Example Usage::
   245  
   246        pcoll | RunInference(
   247                    XGBoostModelHandlerSciPy(
   248                        model_class="XGBoost Model Class",
   249                        model_state="my_model_state.json")))
   250  
   251    Args:
   252      model_class: class of the XGBoost model that defines the model
   253        structure.
   254      model_state: path to a json file that contains the model's
   255        configuration.
   256      inference_fn: the inference function to use during RunInference.
   257        default=default_xgboost_inference_fn
   258    """
   259    def run_inference(
   260        self,
   261        batch: Sequence[scipy.sparse.csr_matrix],
   262        model: Union[xgboost.Booster, xgboost.XGBModel],
   263        inference_args: Optional[Dict[str, Any]] = None
   264    ) -> Iterable[PredictionResult]:
   265      """Runs inferences on a batch of SciPy sparse matrices.
   266  
   267      Args:
   268        batch: A sequence of examples as Scipy sparse matrices.
   269         The dimensions must match the dimensions of the data
   270         used to train the model.
   271        model: XGBoost booster or XBGModel (sklearn interface). Must implement
   272          predict(X). Where the parameter X is a SciPy sparse matrix.
   273        inference_args: Any additional arguments for an inference.
   274  
   275      Returns:
   276        An Iterable of type PredictionResult.
   277      """
   278      return self._inference_fn(batch, model, inference_args)
   279  
   280    def get_num_bytes(self, batch: Sequence[scipy.sparse.csr_matrix]) -> int:
   281      """
   282      Returns:
   283        The number of bytes of data for a batch.
   284      """
   285      return sum(sys.getsizeof(element) for element in batch)
   286  
   287  
   288  class XGBoostModelHandlerDatatable(XGBoostModelHandler[datatable.Frame,
   289                                                         PredictionResult,
   290                                                         Union[xgboost.Booster,
   291                                                               xgboost.XGBModel]]
   292                                     ):
   293    """Implementation of the ModelHandler interface for XGBoost
   294    using datatable dataframes as input.
   295  
   296    Example Usage::
   297  
   298        pcoll | RunInference(
   299                    XGBoostModelHandlerDatatable(
   300                        model_class="XGBoost Model Class",
   301                        model_state="my_model_state.json")))
   302  
   303    Args:
   304      model_class: class of the XGBoost model that defines the model
   305        structure.
   306      model_state: path to a json file that contains the model's
   307        configuration.
   308      inference_fn: the inference function to use during RunInference.
   309        default=default_xgboost_inference_fn
   310    """
   311    def run_inference(
   312        self,
   313        batch: Sequence[datatable.Frame],
   314        model: Union[xgboost.Booster, xgboost.XGBModel],
   315        inference_args: Optional[Dict[str, Any]] = None
   316    ) -> Iterable[PredictionResult]:
   317      """Runs inferences on a batch of datatable dataframe.
   318  
   319      Args:
   320        batch: A sequence of examples as datatable dataframes. Each
   321          row in a dataframe is a single example. The dimensions
   322          must match the dimensions of the data used to train
   323          the model.
   324        model: XGBoost booster or XBGModel (sklearn interface). Must implement
   325          predict(X). Where the parameter X is a datatable dataframe.
   326        inference_args: Any additional arguments for an inference.
   327  
   328      Returns:
   329        An Iterable of type PredictionResult.
   330      """
   331      return self._inference_fn(batch, model, inference_args)
   332  
   333    def get_num_bytes(self, batch: Sequence[datatable.Frame]) -> int:
   334      """
   335      Returns:
   336        The number of bytes of data for a batch.
   337      """
   338      return sum(sys.getsizeof(element) for element in batch)