github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/base.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  # TODO: https://github.com/apache/beam/issues/21822
    18  # mypy: ignore-errors
    19  
    20  """An extensible run inference transform.
    21  
    22  Users of this module can extend the ModelHandler class for any machine learning
    23  framework. A ModelHandler implementation is a required parameter of
    24  RunInference.
    25  
    26  The transform handles standard inference functionality, like metric
    27  collection, sharing models between threads, and batching elements.
    28  """
    29  
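# A bird's-eye sketch of how the pieces below fit together. The handler and
# callables named here are hypothetical; concrete sketches follow the relevant
# classes further down in this file.
#
#   with beam.Pipeline() as p:
#     _ = (
#         p
#         | beam.Create(raw_examples)
#         | RunInference(MyFrameworkModelHandler())  # -> PredictionResult
#         | beam.Map(handle_prediction))
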
    30  import logging
    31  import os
    32  import pickle
    33  import sys
    34  import threading
    35  import time
    36  import uuid
    37  from typing import Any
    38  from typing import Callable
    39  from typing import Dict
    40  from typing import Generic
    41  from typing import Iterable
    42  from typing import Mapping
    43  from typing import NamedTuple
    44  from typing import Optional
    45  from typing import Sequence
    46  from typing import Tuple
    47  from typing import TypeVar
    48  from typing import Union
    49  
    50  import apache_beam as beam
    51  from apache_beam.utils import multi_process_shared
    52  from apache_beam.utils import shared
    53  
    54  try:
    55    # pylint: disable=wrong-import-order, wrong-import-position
    56    import resource
    57  except ImportError:
    58    resource = None  # type: ignore[assignment]
    59  
    60  _NANOSECOND_TO_MILLISECOND = 1_000_000
    61  _NANOSECOND_TO_MICROSECOND = 1_000
    62  
    63  ModelT = TypeVar('ModelT')
    64  ExampleT = TypeVar('ExampleT')
    65  PreProcessT = TypeVar('PreProcessT')
    66  PredictionT = TypeVar('PredictionT')
    67  PostProcessT = TypeVar('PostProcessT')
    68  _INPUT_TYPE = TypeVar('_INPUT_TYPE')
    69  _OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE')
    70  KeyT = TypeVar('KeyT')
    71  
    72  
    73  # We use a NamedTuple to define the structure of PredictionResult.
    74  # However, because support for generic NamedTuples is not available in
    75  # Python versions prior to 3.11, we use the __new__ method to provide
    76  # default values for the fields while maintaining backwards compatibility.
    77  class PredictionResult(NamedTuple('PredictionResult',
    78                                    [('example', _INPUT_TYPE),
    79                                     ('inference', _OUTPUT_TYPE),
    80                                     ('model_id', Optional[str])])):
    81    __slots__ = ()
    82  
    83    def __new__(cls, example, inference, model_id=None):
    84      return super().__new__(cls, example, inference, model_id)
    85  
    86  
    87  PredictionResult.__doc__ = """A NamedTuple containing both input and output
    88    from the inference."""
    89  PredictionResult.example.__doc__ = """The input example."""
    90  PredictionResult.inference.__doc__ = """Results for the inference on the model
    91    for the given example."""
    92  PredictionResult.model_id.__doc__ = """Model ID used to run the prediction."""
    93  
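
# A minimal sketch of constructing a PredictionResult directly, using made-up
# example data. The __new__ override above makes model_id optional, so callers
# that pass only (example, inference) keep working.
def _example_prediction_result() -> PredictionResult:
  # model_id defaults to None via the __new__ override.
  unlabeled = PredictionResult(example=[1.0, 2.0], inference=0.73)
  assert unlabeled.model_id is None
  # A fully populated result; the path below is purely illustrative.
  return PredictionResult(
      example=[1.0, 2.0], inference=0.73, model_id='gs://my-bucket/models/v1')
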
    94  
    95  class ModelMetadata(NamedTuple):
    96    model_id: str
    97    model_name: str
    98  
    99  
   100  class RunInferenceDLQ(NamedTuple):
   101    failed_inferences: beam.PCollection
   102    failed_preprocessing: Sequence[beam.PCollection]
   103    failed_postprocessing: Sequence[beam.PCollection]
   104  
   105  
   106  ModelMetadata.model_id.__doc__ = """Unique identifier for the model. This can be
   107      a file path or a URL where the model can be accessed. It is used to load
   108      the model for inference."""
   109  ModelMetadata.model_name.__doc__ = """Human-readable name for the model. This
   110      can be used to identify the model in the metrics generated by the
   111      RunInference transform."""
   112  
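
# A minimal sketch of a ModelMetadata value, such as one a side-input
# PCollection (for example, one produced by WatchFilePattern) might emit to
# trigger a model refresh. Both field values below are illustrative.
_EXAMPLE_MODEL_METADATA = ModelMetadata(
    model_id='gs://my-bucket/models/2023-06-01/model.pt',
    model_name='my_model_v2')
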
   113  
   114  def _to_milliseconds(time_ns: int) -> int:
   115    return int(time_ns / _NANOSECOND_TO_MILLISECOND)
   116  
   117  
   118  def _to_microseconds(time_ns: int) -> int:
   119    return int(time_ns / _NANOSECOND_TO_MICROSECOND)
   120  
   121  
   122  class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
   123    """Has the ability to load and apply an ML model."""
   124    def __init__(self):
   125      """Environment variables are set using a dict named 'env_vars' before
   126      loading the model. Child classes can accept this dict as a kwarg."""
   127      self._env_vars = {}
   128  
   129    def load_model(self) -> ModelT:
   130      """Loads and initializes a model for processing."""
   131      raise NotImplementedError(type(self))
   132  
   133    def run_inference(
   134        self,
   135        batch: Sequence[ExampleT],
   136        model: ModelT,
   137        inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
   138      """Runs inferences on a batch of examples.
   139  
   140      Args:
   141        batch: A sequence of examples or features.
   142        model: The model used to make inferences.
   143        inference_args: Extra arguments for models whose inference call requires
   144          extra parameters.
   145  
   146      Returns:
   147        An Iterable of Predictions.
   148      """
   149      raise NotImplementedError(type(self))
   150  
   151    def get_num_bytes(self, batch: Sequence[ExampleT]) -> int:
   152      """
   153      Returns:
   154         The number of bytes of data for a batch.
   155      """
   156      return len(pickle.dumps(batch))
   157  
   158    def get_metrics_namespace(self) -> str:
   159      """
   160      Returns:
   161         A namespace for metrics collected by the RunInference transform.
   162      """
   163      return 'RunInference'
   164  
   165    def get_resource_hints(self) -> dict:
   166      """
   167      Returns:
   168         Resource hints for the transform.
   169      """
   170      return {}
   171  
   172    def batch_elements_kwargs(self) -> Mapping[str, Any]:
   173      """
   174      Returns:
   175         kwargs suitable for beam.BatchElements.
   176      """
   177      return {}
   178  
   179    def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
   180      """Validates inference_args passed in the inference call.
   181  
   182      Because most frameworks do not need extra arguments in their predict() call,
   183      the default behavior is to error out if inference_args are present.
   184      """
   185      if inference_args:
   186        raise ValueError(
   187            'inference_args were provided, but should be None because this '
   188            'framework does not expect extra arguments on inferences.')
   189  
   190    def update_model_path(self, model_path: Optional[str] = None):
   191      """Update the model paths produced by side inputs."""
   192      pass
   193  
   194    def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   195      """Gets all preprocessing functions to be run before batching/inference.
   196      Functions are returned in the order in which they should be applied."""
   197      return []
   198  
   199    def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   200      """Gets all postprocessing functions to be run after inference.
   201      Functions are returned in the order in which they should be applied."""
   202      return []
   203  
   204    def set_environment_vars(self):
   205      """Sets environment variables using a dictionary provided via kwargs.
   206      Keys are the env variable names, and values are the env variable values.
   207      Child ModelHandler classes should set _env_vars via kwargs in __init__,
   208      or else call super().__init__()."""
   209      env_vars = getattr(self, '_env_vars', {})
   210      for env_variable, env_value in env_vars.items():
   211        os.environ[env_variable] = env_value
   212  
   213    def with_preprocess_fn(
   214        self, fn: Callable[[PreProcessT], ExampleT]
   215    ) -> 'ModelHandler[PreProcessT, PredictionT, ModelT]':
   216      """Returns a new ModelHandler with a preprocessing function
   217      associated with it. The preprocessing function will be run
   218      before batching/inference and should map your input PCollection
   219      to the base ModelHandler's input type. If you apply multiple
   220      preprocessing functions, they will be run on your original
   221      PCollection in order from last applied to first applied."""
   222      return _PreProcessingModelHandler(self, fn)
   223  
   224    def with_postprocess_fn(
   225        self, fn: Callable[[PredictionT], PostProcessT]
   226    ) -> 'ModelHandler[ExampleT, PostProcessT, ModelT]':
   227      """Returns a new ModelHandler with a postprocessing function
   228      associated with it. The postprocessing function will be run
   229      after inference and should map the base ModelHandler's output
   230      type to your desired output type. If you apply multiple
   231      postprocessing functions, they will be run on your original
   232      inference result in order from first applied to last applied."""
   233      return _PostProcessingModelHandler(self, fn)
   234  
   235    def share_model_across_processes(self) -> bool:
   236      """Returns a boolean representing whether or not a model should
   237      be shared across multiple processes instead of being loaded per process.
   238      This is primarily useful for large models that can't fit multiple copies
   239      in memory. Multi-process support may vary by runner, but this will fall
   240      back to loading per process as necessary. See
   241      https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
   242      return False
   243  
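
# A minimal sketch of a custom ModelHandler, assuming a hypothetical in-memory
# threshold "model". Every name prefixed with _Example is illustrative only and
# not part of the Beam API. The sketch shows the methods most handlers
# override (load_model, run_inference, batch_elements_kwargs), plus the
# env_vars kwarg pattern described in ModelHandler.__init__.
class _ExampleThresholdModel:
  """Hypothetical model that labels floats above a threshold."""
  def __init__(self, threshold: float = 0.5):
    self._threshold = threshold

  def predict(self, values: Sequence[float]) -> Sequence[int]:
    return [int(v > self._threshold) for v in values]


class _ExampleModelHandler(ModelHandler[float,
                                        PredictionResult,
                                        _ExampleThresholdModel]):
  def __init__(self, threshold: float = 0.5, **kwargs):
    super().__init__()
    self._threshold = threshold
    # Forwarded to set_environment_vars() before load_model() is called.
    self._env_vars = kwargs.get('env_vars', {})

  def load_model(self) -> _ExampleThresholdModel:
    return _ExampleThresholdModel(self._threshold)

  def run_inference(
      self,
      batch: Sequence[float],
      model: _ExampleThresholdModel,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    for example, label in zip(batch, model.predict(batch)):
      yield PredictionResult(example=example, inference=label)

  def batch_elements_kwargs(self) -> Mapping[str, Any]:
    # Caps the batch sizes that beam.BatchElements hands to run_inference.
    return {'max_batch_size': 64}
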
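
# A sketch of composing pre/postprocessing onto the illustrative handler
# above. Note the ordering: with_preprocess_fn prepends, so the last
# preprocess function applied is the first one run, while with_postprocess_fn
# appends, so postprocess functions run in the order they were applied.
def _example_composed_handler() -> ModelHandler:
  return (
      _ExampleModelHandler(threshold=0.7)
      .with_preprocess_fn(float)      # Runs second (applied first).
      .with_preprocess_fn(str.strip)  # Runs first (applied last).
      .with_postprocess_fn(lambda result: result.inference))
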
   244  
   245  class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
   246                          ModelHandler[Tuple[KeyT, ExampleT],
   247                                       Tuple[KeyT, PredictionT],
   248                                       ModelT]):
   249    def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
   250      """A ModelHandler that takes keyed examples and returns keyed predictions.
   251  
   252      For example, if the original model is used with RunInference to take a
   253      PCollection[E] to a PCollection[P], this ModelHandler would take a
   254      PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], making it possible
   255      to use the key to associate the outputs with the inputs.
   256  
   257      Args:
   258        unkeyed: An implementation of ModelHandler that does not require keys.
   259      """
   260      if len(unkeyed.get_preprocess_fns()) or len(unkeyed.get_postprocess_fns()):
   261        raise Exception(
   262            'Cannot wrap an unkeyed model handler that defines pre or '
   263            'postprocessing functions in a keyed model handler. All '
   264            'pre/postprocessing functions must be defined on the outer model '
   265            'handler.')
   266      self._unkeyed = unkeyed
   267      self._env_vars = unkeyed._env_vars
   268  
   269    def load_model(self) -> ModelT:
   270      return self._unkeyed.load_model()
   271  
   272    def run_inference(
   273        self,
   274        batch: Sequence[Tuple[KeyT, ExampleT]],
   275        model: ModelT,
   276        inference_args: Optional[Dict[str, Any]] = None
   277    ) -> Iterable[Tuple[KeyT, PredictionT]]:
   278      keys, unkeyed_batch = zip(*batch)
   279      return zip(
   280          keys, self._unkeyed.run_inference(unkeyed_batch, model, inference_args))
   281  
   282    def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
   283      keys, unkeyed_batch = zip(*batch)
   284      return len(pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
   285  
   286    def get_metrics_namespace(self) -> str:
   287      return self._unkeyed.get_metrics_namespace()
   288  
   289    def get_resource_hints(self):
   290      return self._unkeyed.get_resource_hints()
   291  
   292    def batch_elements_kwargs(self):
   293      return self._unkeyed.batch_elements_kwargs()
   294  
   295    def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
   296      return self._unkeyed.validate_inference_args(inference_args)
   297  
   298    def update_model_path(self, model_path: Optional[str] = None):
   299      return self._unkeyed.update_model_path(model_path=model_path)
   300  
   301    def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   302      return self._unkeyed.get_preprocess_fns()
   303  
   304    def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   305      return self._unkeyed.get_postprocess_fns()
   306  
   307    def share_model_across_processes(self) -> bool:
   308      return self._unkeyed.share_model_across_processes()
   309  
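
# A sketch of keyed inference, assuming the illustrative _ExampleModelHandler
# defined above. Keys travel alongside the examples so that predictions can be
# joined back to their inputs.
def _example_keyed_inference(pipeline: beam.Pipeline) -> beam.PCollection:
  keyed_examples = pipeline | 'CreateKeyed' >> beam.Create([
      ('id-1', 0.2), ('id-2', 0.9)])
  # Elements of the output are (key, PredictionResult) tuples.
  return keyed_examples | 'KeyedInference' >> RunInference(
      KeyedModelHandler(_ExampleModelHandler(threshold=0.5)))
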
   310  
   311  class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
   312                               ModelHandler[Union[ExampleT, Tuple[KeyT,
   313                                                                  ExampleT]],
   314                                            Union[PredictionT,
   315                                                  Tuple[KeyT, PredictionT]],
   316                                            ModelT]):
   317    def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
   318      """A ModelHandler that takes examples that might have keys and returns
   319      predictions that might have keys.
   320  
   321      For example, if the original model is used with RunInference to take a
   322      PCollection[E] to a PCollection[P], this ModelHandler would take either
   323      PCollection[E] to a PCollection[P] or PCollection[Tuple[K, E]] to a
   324      PCollection[Tuple[K, P]], depending on whether the elements are
   325      tuples. This pattern makes it possible to associate the outputs with the
   326      inputs based on the key.
   327  
   328      Note that you cannot use this ModelHandler if E is a tuple type.
   329      In addition, either all examples should be keyed, or none of them.
   330  
   331      Args:
   332        unkeyed: An implementation of ModelHandler that does not require keys.
   333      """
   334      if len(unkeyed.get_preprocess_fns()) or len(unkeyed.get_postprocess_fns()):
   335        raise Exception(
   336            'Cannot wrap an unkeyed model handler that defines pre or '
   337            'postprocessing functions in a keyed model handler. All '
   338            'pre/postprocessing functions must be defined on the outer model '
   339            'handler.')
   340      self._unkeyed = unkeyed
   341      self._env_vars = unkeyed._env_vars
   342  
   343    def load_model(self) -> ModelT:
   344      return self._unkeyed.load_model()
   345  
   346    def run_inference(
   347        self,
   348        batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
   349        model: ModelT,
   350        inference_args: Optional[Dict[str, Any]] = None
   351    ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
   352      # Really the input should be
   353      #    Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]]
   354      # but there's not a good way to express (or check) that.
   355      if isinstance(batch[0], tuple):
   356        is_keyed = True
   357        keys, unkeyed_batch = zip(*batch)  # type: ignore[arg-type]
   358      else:
   359        is_keyed = False
   360        unkeyed_batch = batch  # type: ignore[assignment]
   361      unkeyed_results = self._unkeyed.run_inference(
   362          unkeyed_batch, model, inference_args)
   363      if is_keyed:
   364        return zip(keys, unkeyed_results)
   365      else:
   366        return unkeyed_results
   367  
   368    def get_num_bytes(
   369        self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
   370      # MyPy can't follow the branching logic.
   371      if isinstance(batch[0], tuple):
   372        keys, unkeyed_batch = zip(*batch)  # type: ignore[arg-type]
   373        return len(
   374            pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
   375      else:
   376        return self._unkeyed.get_num_bytes(batch)  # type: ignore[arg-type]
   377  
   378    def get_metrics_namespace(self) -> str:
   379      return self._unkeyed.get_metrics_namespace()
   380  
   381    def get_resource_hints(self):
   382      return self._unkeyed.get_resource_hints()
   383  
   384    def batch_elements_kwargs(self):
   385      return self._unkeyed.batch_elements_kwargs()
   386  
   387    def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
   388      return self._unkeyed.validate_inference_args(inference_args)
   389  
   390    def update_model_path(self, model_path: Optional[str] = None):
   391      return self._unkeyed.update_model_path(model_path=model_path)
   392  
   393    def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   394      return self._unkeyed.get_preprocess_fns()
   395  
   396    def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   397      return self._unkeyed.get_postprocess_fns()
   398  
   399    def share_model_across_processes(self) -> bool:
   400      return self._unkeyed.share_model_across_processes()
   401  
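
# A brief sketch: the same wrapped handler accepts either keyed or unkeyed
# elements, as long as a single pipeline sticks to one form throughout.
def _example_maybe_keyed_handler() -> ModelHandler:
  # Works for PCollection[float] or PCollection[Tuple[str, float]] inputs.
  return MaybeKeyedModelHandler(_ExampleModelHandler())
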
   402  
   403  class _PreProcessingModelHandler(Generic[ExampleT,
   404                                           PredictionT,
   405                                           ModelT,
   406                                           PreProcessT],
   407                                   ModelHandler[PreProcessT, PredictionT,
   408                                                ModelT]):
   409    def __init__(
   410        self,
   411        base: ModelHandler[ExampleT, PredictionT, ModelT],
   412        preprocess_fn: Callable[[PreProcessT], ExampleT]):
   413      """A ModelHandler that has a preprocessing function associated with it.
   414  
   415      Args:
   416        base: An implementation of the underlying model handler.
   417        preprocess_fn: the preprocessing function to use.
   418      """
   419      self._base = base
   420      self._env_vars = base._env_vars
   421      self._preprocess_fn = preprocess_fn
   422  
   423    def load_model(self) -> ModelT:
   424      return self._base.load_model()
   425  
   426    def run_inference(
   427        self,
   428        batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
   429        model: ModelT,
   430        inference_args: Optional[Dict[str, Any]] = None
   431    ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
   432      return self._base.run_inference(batch, model, inference_args)
   433  
   434    def get_num_bytes(
   435        self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
   436      return self._base.get_num_bytes(batch)
   437  
   438    def get_metrics_namespace(self) -> str:
   439      return self._base.get_metrics_namespace()
   440  
   441    def get_resource_hints(self):
   442      return self._base.get_resource_hints()
   443  
   444    def batch_elements_kwargs(self):
   445      return self._base.batch_elements_kwargs()
   446  
   447    def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
   448      return self._base.validate_inference_args(inference_args)
   449  
   450    def update_model_path(self, model_path: Optional[str] = None):
   451      return self._base.update_model_path(model_path=model_path)
   452  
   453    def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   454      return [self._preprocess_fn] + self._base.get_preprocess_fns()
   455  
   456    def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   457      return self._base.get_postprocess_fns()
   458  
   459  
   460  class _PostProcessingModelHandler(Generic[ExampleT,
   461                                            PredictionT,
   462                                            ModelT,
   463                                            PostProcessT],
   464                                    ModelHandler[ExampleT, PostProcessT, ModelT]):
   465    def __init__(
   466        self,
   467        base: ModelHandler[ExampleT, PredictionT, ModelT],
   468        postprocess_fn: Callable[[PredictionT], PostProcessT]):
   469      """A ModelHandler that has a postprocessing function associated with it.
   470  
   471      Args:
   472        base: An implementation of the underlying model handler.
   473        postprocess_fn: the postprocessing function to use.
   474      """
   475      self._base = base
   476      self._env_vars = base._env_vars
   477      self._postprocess_fn = postprocess_fn
   478  
   479    def load_model(self) -> ModelT:
   480      return self._base.load_model()
   481  
   482    def run_inference(
   483        self,
   484        batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
   485        model: ModelT,
   486        inference_args: Optional[Dict[str, Any]] = None
   487    ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
   488      return self._base.run_inference(batch, model, inference_args)
   489  
   490    def get_num_bytes(
   491        self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
   492      return self._base.get_num_bytes(batch)
   493  
   494    def get_metrics_namespace(self) -> str:
   495      return self._base.get_metrics_namespace()
   496  
   497    def get_resource_hints(self):
   498      return self._base.get_resource_hints()
   499  
   500    def batch_elements_kwargs(self):
   501      return self._base.batch_elements_kwargs()
   502  
   503    def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
   504      return self._base.validate_inference_args(inference_args)
   505  
   506    def update_model_path(self, model_path: Optional[str] = None):
   507      return self._base.update_model_path(model_path=model_path)
   508  
   509    def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   510      return self._base.get_preprocess_fns()
   511  
   512    def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   513      return self._base.get_postprocess_fns() + [self._postprocess_fn]
   514  
   515  
   516  class RunInference(beam.PTransform[beam.PCollection[ExampleT],
   517                                     beam.PCollection[PredictionT]]):
   518    def __init__(
   519        self,
   520        model_handler: ModelHandler[ExampleT, PredictionT, Any],
   521        clock=time,
   522        inference_args: Optional[Dict[str, Any]] = None,
   523        metrics_namespace: Optional[str] = None,
   524        *,
   525        model_metadata_pcoll: beam.PCollection[ModelMetadata] = None,
   526        watch_model_pattern: Optional[str] = None,
   527        **kwargs):
   528      """
   529      A transform that takes a PCollection of examples (or features) for use
   530      on an ML model. The transform then outputs inferences (or predictions) for
   531      those examples in a PCollection of PredictionResults that contains the input
   532      examples and the output inferences.
   533  
   534      Models for supported frameworks can be loaded using a URI. Supported
   535      services can also be used.
   536  
   537      This transform attempts to batch examples using the beam.BatchElements
   538      transform. Batching can be configured using the ModelHandler.
   539  
   540      Args:
   541          model_handler: An implementation of ModelHandler.
   542          clock: A clock implementing time_ns. *Used for unit testing.*
   543          inference_args: Extra arguments for models whose inference call requires
   544            extra parameters.
   545          metrics_namespace: Namespace of the transform to collect metrics.
   546          model_metadata_pcoll: PCollection that emits a singleton ModelMetadata
   547            containing the model path and model name, used as a side input to
   548            the _RunInferenceDoFn.
   549          watch_model_pattern: A glob pattern used to watch a directory
   550            for automatic model refresh.
   551      """
   552      self._model_handler = model_handler
   553      self._inference_args = inference_args
   554      self._clock = clock
   555      self._metrics_namespace = metrics_namespace
   556      self._model_metadata_pcoll = model_metadata_pcoll
   557      self._enable_side_input_loading = self._model_metadata_pcoll is not None
   558      self._with_exception_handling = False
   559      self._watch_model_pattern = watch_model_pattern
   560      self._kwargs = kwargs
   561      # Generate a random tag to use for shared.py and multi_process_shared.py to
   562      # allow us to effectively disambiguate in multi-model settings.
   563      self._model_tag = uuid.uuid4().hex
   564  
   565    def _get_model_metadata_pcoll(self, pipeline):
   566      # avoid circular imports.
   567      # pylint: disable=wrong-import-position
   568      from apache_beam.ml.inference.utils import WatchFilePattern
   569      extra_params = {}
   570      if 'interval' in self._kwargs:
   571        extra_params['interval'] = self._kwargs['interval']
   572      if 'stop_timestamp' in self._kwargs:
   573        extra_params['stop_timestamp'] = self._kwargs['stop_timestamp']
   574  
   575      return (
   576          pipeline | WatchFilePattern(
   577              file_pattern=self._watch_model_pattern, **extra_params))
   578  
   579    # TODO(BEAM-14046): Add and link to help documentation.
   580    @classmethod
   581    def from_callable(cls, model_handler_provider, **kwargs):
   582      """Multi-language friendly constructor.
   583  
   584      Use this constructor with fully_qualified_named_transform to
   585      initialize the RunInference transform from PythonCallableSource provided
   586      by foreign SDKs.
   587  
   588      Args:
   589        model_handler_provider: A callable object that returns ModelHandler.
   590        kwargs: Keyword arguments for model_handler_provider.
   591      """
   592      return cls(model_handler_provider(**kwargs))
   593  
   594    def _apply_fns(
   595        self,
   596        pcoll: beam.PCollection,
   597        fns: Iterable[Callable[[Any], Any]],
   598        step_prefix: str) -> Tuple[beam.PCollection, Iterable[beam.PCollection]]:
   599      bad_preprocessed = []
   600      for idx in range(len(fns)):
   601        fn = fns[idx]
   602        if self._with_exception_handling:
   603          pcoll, bad = (pcoll
   604          | f"{step_prefix}-{idx}" >> beam.Map(
   605            fn).with_exception_handling(
   606            exc_class=self._exc_class,
   607            use_subprocess=self._use_subprocess,
   608            threshold=self._threshold))
   609          bad_preprocessed.append(bad)
   610        else:
   611          pcoll = pcoll | f"{step_prefix}-{idx}" >> beam.Map(fn)
   612  
   613      return pcoll, bad_preprocessed
   614  
   615    # TODO(https://github.com/apache/beam/issues/21447): Add batch_size back off
   616    # in the case there are functional reasons large batch sizes cannot be
   617    # handled.
   618    def expand(
   619        self, pcoll: beam.PCollection[ExampleT]) -> beam.PCollection[PredictionT]:
   620      self._model_handler.validate_inference_args(self._inference_args)
   621      # DLQ pcollections
   622      bad_preprocessed = []
   623      bad_inference = None
   624      bad_postprocessed = []
   625      preprocess_fns = self._model_handler.get_preprocess_fns()
   626      postprocess_fns = self._model_handler.get_postprocess_fns()
   627  
   628      pcoll, bad_preprocessed = self._apply_fns(
   629        pcoll, preprocess_fns, 'BeamML_RunInference_Preprocess')
   630  
   631      resource_hints = self._model_handler.get_resource_hints()
   632  
   633      # check for the side input
   634      if self._watch_model_pattern:
   635        self._model_metadata_pcoll = self._get_model_metadata_pcoll(
   636            pcoll.pipeline)
   637  
   638      batched_elements_pcoll = (
   639          pcoll
   640          # TODO(https://github.com/apache/beam/issues/21440): Hook into the
   641          # batching DoFn APIs.
   642          | beam.BatchElements(**self._model_handler.batch_elements_kwargs()))
   643  
   644      run_inference_pardo = beam.ParDo(
   645          _RunInferenceDoFn(
   646              self._model_handler,
   647              self._clock,
   648              self._metrics_namespace,
   649              self._enable_side_input_loading,
   650              self._model_tag),
   651          self._inference_args,
   652          beam.pvalue.AsSingleton(
   653              self._model_metadata_pcoll,
   654          ) if self._enable_side_input_loading else None).with_resource_hints(
   655              **resource_hints)
   656  
   657      if self._with_exception_handling:
   658        results, bad_inference = (
   659            batched_elements_pcoll
   660            | 'BeamML_RunInference' >>
   661            run_inference_pardo.with_exception_handling(
   662            exc_class=self._exc_class,
   663            use_subprocess=self._use_subprocess,
   664            threshold=self._threshold))
   665      else:
   666        results = (
   667            batched_elements_pcoll
   668            | 'BeamML_RunInference' >> run_inference_pardo)
   669  
   670      results, bad_postprocessed = self._apply_fns(
   671        results, postprocess_fns, 'BeamML_RunInference_Postprocess')
   672  
   673      if self._with_exception_handling:
   674        dlq = RunInferenceDLQ(bad_inference, bad_preprocessed, bad_postprocessed)
   675        return results, dlq
   676  
   677      return results
   678  
   679    def with_exception_handling(
   680        self, *, exc_class=Exception, use_subprocess=False, threshold=1):
   681      """Automatically provides a dead letter output for skipping bad records.
   682      This can allow a pipeline to continue successfully rather than fail or
   683      continuously throw errors on retry when bad elements are encountered.
   684  
   685      This returns two outputs: the first is the PCollection of successfully
   686      processed results, and the second is a RunInferenceDLQ collecting the bad
   687      batches and records (those which threw exceptions during processing) along
   688      with information about the errors raised.
   689  
   690      For example, one would write::
   691  
   692          main, other = RunInference(
   693            maybe_error_raising_model_handler
   694          ).with_exception_handling()
   695  
   696      and `main` will be a PCollection of PredictionResults and `other` will
   697      contain a `RunInferenceDLQ` object with PCollections containing failed
   698      records for each failed inference, preprocess operation, or postprocess
   699      operation. To access each collection of failed records, one would write::
   700  
   701          failed_inferences = other.failed_inferences
   702          failed_preprocessing = other.failed_preprocessing
   703          failed_postprocessing = other.failed_postprocessing
   704  
   705      failed_inferences is in the form
   706      PCollection[Tuple[failed batch, exception]].
   707  
   708      failed_preprocessing is in the form
   709      List[PCollection[Tuple[failed record, exception]]], where each element of
   710      the list corresponds to a preprocess function. These PCollections are
   711      in the same order that the preprocess functions are applied.
   712  
   713      failed_postprocessing is in the form
   714      List[PCollection[Tuple[failed record, exception]]], where each element of
   715      the list corresponds to a postprocess function. These PCollections are
   716      in the same order that the postprocess functions are applied.
   717  
   718  
   719      Args:
   720        exc_class: An exception class, or tuple of exception classes, to catch.
   721            Optional, defaults to 'Exception'.
   722        use_subprocess: Whether to execute the DoFn logic in a subprocess. This
   723            allows one to recover from errors that can crash the calling process
   724            (e.g. from an underlying library causing a segfault), but is
   725            slower as elements and results must cross a process boundary.  Note
   726            that this starts up a long-running process that is used to handle
   727            all the elements (until hard failure, which should be rare) rather
   728            than a new process per element, so the overhead should be minimal
   729            (and can be amortized if there's any per-process or per-bundle
   730            initialization that needs to be done). Optional, defaults to False.
   731        threshold: An upper bound on the ratio of records that can be bad before
   732            aborting the entire pipeline. Optional, defaults to 1.0 (meaning
   733            up to 100% of records can be bad and the pipeline will still succeed).
   734      """
   735      self._with_exception_handling = True
   736      self._exc_class = exc_class
   737      self._use_subprocess = use_subprocess
   738      self._threshold = threshold
   739      return self
   740  
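
# A sketch of end-to-end use with the dead-letter queue, assuming the
# illustrative handlers defined earlier in this file. The element values are
# made up; 'oops' cannot be parsed by the float() preprocess step.
def _example_run_inference_with_dlq(pipeline: beam.Pipeline):
  examples = pipeline | 'Create' >> beam.Create([' 0.1 ', ' 0.8 ', 'oops'])
  results, dlq = (
      examples
      | 'Inference' >> RunInference(
          _example_composed_handler()).with_exception_handling())
  # 'oops' fails in float(), the second preprocess function, so it lands in
  # dlq.failed_preprocessing[1] instead of failing the pipeline.
  _ = dlq.failed_preprocessing[1] | 'LogBadPreprocess' >> beam.Map(print)
  return results
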
   741  
   742  class _MetricsCollector:
   743    """A metrics collector that tracks ML-related performance and memory usage."""
   744    def __init__(self, namespace: str, prefix: str = ''):
   745      """
   746      Args:
   747       namespace: Namespace for the metrics.
   748       prefix: Unique identifier for metrics, used when models
   749        are updated using side input.
   750      """
   751      # Metrics
   752      if prefix:
   753        prefix = f'{prefix}_'
   754      self._inference_counter = beam.metrics.Metrics.counter(
   755          namespace, prefix + 'num_inferences')
   756      self.failed_batches_counter = beam.metrics.Metrics.counter(
   757          namespace, prefix + 'failed_batches_counter')
   758      self._inference_request_batch_size = beam.metrics.Metrics.distribution(
   759          namespace, prefix + 'inference_request_batch_size')
   760      self._inference_request_batch_byte_size = (
   761          beam.metrics.Metrics.distribution(
   762              namespace, prefix + 'inference_request_batch_byte_size'))
   763      # Batch inference latency in microseconds.
   764      self._inference_batch_latency_micro_secs = (
   765          beam.metrics.Metrics.distribution(
   766              namespace, prefix + 'inference_batch_latency_micro_secs'))
   767      self._model_byte_size = beam.metrics.Metrics.distribution(
   768          namespace, prefix + 'model_byte_size')
   769      # Model load latency in milliseconds.
   770      self._load_model_latency_milli_secs = beam.metrics.Metrics.distribution(
   771          namespace, prefix + 'load_model_latency_milli_secs')
   772  
   773      # Metrics cache
   774      self._load_model_latency_milli_secs_cache = None
   775      self._model_byte_size_cache = None
   776  
   777    def update_metrics_with_cache(self):
   778      if self._load_model_latency_milli_secs_cache is not None:
   779        self._load_model_latency_milli_secs.update(
   780            self._load_model_latency_milli_secs_cache)
   781        self._load_model_latency_milli_secs_cache = None
   782      if self._model_byte_size_cache is not None:
   783        self._model_byte_size.update(self._model_byte_size_cache)
   784        self._model_byte_size_cache = None
   785  
   786    def cache_load_model_metrics(self, load_model_latency_ms, model_byte_size):
   787      self._load_model_latency_milli_secs_cache = load_model_latency_ms
   788      self._model_byte_size_cache = model_byte_size
   789  
   790    def update(
   791        self,
   792        examples_count: int,
   793        examples_byte_size: int,
   794        latency_micro_secs: int):
   795      self._inference_batch_latency_micro_secs.update(latency_micro_secs)
   796      self._inference_counter.inc(examples_count)
   797      self._inference_request_batch_size.update(examples_count)
   798      self._inference_request_batch_byte_size.update(examples_byte_size)
   799  
   800  
   801  class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
   802    def __init__(
   803        self,
   804        model_handler: ModelHandler[ExampleT, PredictionT, Any],
   805        clock,
   806        metrics_namespace,
   807        enable_side_input_loading: bool = False,
   808        model_tag: str = "RunInference"):
   809      """A DoFn implementation generic to frameworks.
   810  
   811        Args:
   812          model_handler: An implementation of ModelHandler.
   813          clock: A clock implementing time_ns. *Used for unit testing.*
   814          metrics_namespace: Namespace of the transform to collect metrics.
   815          enable_side_input_loading: Bool indicating whether the model is
   816              updated via side inputs.
   817          model_tag: Tag to use to disambiguate models in multi-model settings.
   818      """
   819      self._model_handler = model_handler
   820      self._shared_model_handle = shared.Shared()
   821      self._clock = clock
   822      self._model = None
   823      self._metrics_namespace = metrics_namespace
   824      self._enable_side_input_loading = enable_side_input_loading
   825      self._side_input_path = None
   826      self._model_tag = model_tag
   827  
   828    def _load_model(self, side_input_model_path: Optional[str] = None):
   829      def load():
   830        """Function for constructing shared LoadedModel."""
   831        memory_before = _get_current_process_memory_in_bytes()
   832        start_time = _to_milliseconds(self._clock.time_ns())
   833        self._model_handler.update_model_path(side_input_model_path)
   834        model = self._model_handler.load_model()
   835        end_time = _to_milliseconds(self._clock.time_ns())
   836        memory_after = _get_current_process_memory_in_bytes()
   837        load_model_latency_ms = end_time - start_time
   838        model_byte_size = memory_after - memory_before
   839        self._metrics_collector.cache_load_model_metrics(
   840            load_model_latency_ms, model_byte_size)
   841        return model
   842  
   843      # TODO(https://github.com/apache/beam/issues/21443): Investigate releasing
   844      # model.
   845      if self._model_handler.share_model_across_processes():
   846        model = multi_process_shared.MultiProcessShared(
   847            load, tag=side_input_model_path or self._model_tag).acquire()
   848      else:
   849        model = self._shared_model_handle.acquire(
   850            load, tag=side_input_model_path or self._model_tag)
   851      # Since the shared model handle is shared across threads, the model path
   852      # in the model handler might not get updated, because we may receive a
   853      # cached weak-ref model from the shared cache instead of calling load().
   854      # As a sanity check, call update_model_path again.
   855      self._model_handler.update_model_path(side_input_model_path)
   856      return model
   857  
   858    def get_metrics_collector(self, prefix: str = ''):
   859      """
   860      Args:
   861        prefix: Unique identifier for metrics, used when models
   862        are updated using side input.
   863      """
   864      metrics_namespace = (
   865          self._metrics_namespace) if self._metrics_namespace else (
   866              self._model_handler.get_metrics_namespace())
   867      return _MetricsCollector(metrics_namespace, prefix=prefix)
   868  
   869    def setup(self):
   870      self._metrics_collector = self.get_metrics_collector()
   871      self._model_handler.set_environment_vars()
   872      if not self._enable_side_input_loading:
   873        self._model = self._load_model()
   874  
   875    def update_model(self, side_input_model_path: Optional[str] = None):
   876      self._model = self._load_model(side_input_model_path=side_input_model_path)
   877  
   878    def _run_inference(self, batch, inference_args):
   879      start_time = _to_microseconds(self._clock.time_ns())
   880      try:
   881        result_generator = self._model_handler.run_inference(
   882            batch, self._model, inference_args)
   883      except BaseException as e:
   884        self._metrics_collector.failed_batches_counter.inc()
   885        raise e
   886      predictions = list(result_generator)
   887  
   888      end_time = _to_microseconds(self._clock.time_ns())
   889      inference_latency = end_time - start_time
   890      num_bytes = self._model_handler.get_num_bytes(batch)
   891      num_elements = len(batch)
   892      self._metrics_collector.update(num_elements, num_bytes, inference_latency)
   893  
   894      return predictions
   895  
   896    def process(
   897        self, batch, inference_args, si_model_metadata: Optional[ModelMetadata]):
   898      """
   899      When side input is enabled:
   900        The method checks if the side input model has been updated, and if so,
   901        updates the model and runs inference on the batch of data. If the
   902        side input is empty or the model has not been updated, the method
   903        simply runs inference on the batch of data.
   904      """
   905      if si_model_metadata:
   906        if isinstance(si_model_metadata, beam.pvalue.EmptySideInput):
   907          self.update_model(side_input_model_path=None)
   908          return self._run_inference(batch, inference_args)
   909        elif self._side_input_path != si_model_metadata.model_id:
   910          self._side_input_path = si_model_metadata.model_id
   911          self._metrics_collector = self.get_metrics_collector(
   912              prefix=si_model_metadata.model_name)
   913          with threading.Lock():
   914            self.update_model(si_model_metadata.model_id)
   915            return self._run_inference(batch, inference_args)
   916      return self._run_inference(batch, inference_args)
   917  
   918    def finish_bundle(self):
   919      # TODO(https://github.com/apache/beam/issues/21435): Figure out why there
   920      # is a cache.
   921      self._metrics_collector.update_metrics_with_cache()
   922  
   923  
   924  def _is_darwin() -> bool:
   925    return sys.platform == 'darwin'
   926  
   927  
   928  def _get_current_process_memory_in_bytes():
   929    """
   930    Returns:
   931      memory usage in bytes.
   932    """
   933  
   934    if resource is not None:
   935      usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
   936      if _is_darwin():
   937        return usage
   938      return usage * 1024
   939    else:
   940      logging.warning(
   941          'Resource module is not available for current platform, '
   942          'memory usage cannot be fetched.')
   943    return 0