github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/onnx_inference.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence

import numpy

import onnx
import onnxruntime as ort
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult

__all__ = ['OnnxModelHandlerNumpy']

# Signature of a user-supplied inference function; it matches
# default_numpy_inference_fn below and is how OnnxModelHandlerNumpy
# invokes it in run_inference.
NumpyInferenceFn = Callable[
    [ort.InferenceSession, Sequence[numpy.ndarray], Optional[Dict[str, Any]]],
    Iterable[PredictionResult]]


def default_numpy_inference_fn(
    inference_session: ort.InferenceSession,
    batch: Sequence[numpy.ndarray],
    inference_args: Optional[Dict[str, Any]] = None) -> Any:
  # Stack the single examples into one batched array and feed it to the
  # first graph input of the ONNX model.
  ort_inputs = {
      inference_session.get_inputs()[0].name: numpy.stack(batch, axis=0)
  }
  # Any extra inference_args are passed through as additional named inputs.
  if inference_args:
    ort_inputs = {**ort_inputs, **inference_args}
  # run(None, ...) computes all graph outputs; only the first is returned.
  ort_outs = inference_session.run(None, ort_inputs)[0]
  return ort_outs


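# Illustrative example (not part of the Beam module): a custom
# NumpyInferenceFn can be supplied to OnnxModelHandlerNumpy via its
# `inference_fn` argument. The sketch below mirrors
# default_numpy_inference_fn but returns every declared model output,
# zipped per example, instead of only the first output. The function name
# and the model path in the usage comment are hypothetical.
def example_all_outputs_inference_fn(
    inference_session: ort.InferenceSession,
    batch: Sequence[numpy.ndarray],
    inference_args: Optional[Dict[str, Any]] = None) -> Any:
  ort_inputs = {
      inference_session.get_inputs()[0].name: numpy.stack(batch, axis=0)
  }
  if inference_args:
    ort_inputs = {**ort_inputs, **inference_args}
  # Request every declared output; each element of ort_outs is an array
  # whose first dimension is the batch dimension.
  output_names = [output.name for output in inference_session.get_outputs()]
  ort_outs = inference_session.run(output_names, ort_inputs)
  # Zip the outputs per example so run_inference can pair each input with
  # its own tuple of outputs via utils._convert_to_result.
  return list(zip(*ort_outs))


# Usage sketch:
#   handler = OnnxModelHandlerNumpy(
#       model_uri="gs://my-bucket/model.onnx",  # hypothetical path
#       inference_fn=example_all_outputs_inference_fn)

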
class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
                                         PredictionResult,
                                         ort.InferenceSession]):
  def __init__( #pylint: disable=dangerous-default-value
      self,
      model_uri: str,
      session_options=None,
      providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
      provider_options=None,
      *,
      inference_fn: NumpyInferenceFn = default_numpy_inference_fn,
      **kwargs):
    """Implementation of the ModelHandler interface for ONNX models
    using numpy arrays as input.
    Note that inputs to OnnxModelHandlerNumpy should all be the same size.

    Example Usage::

      pcoll | RunInference(OnnxModelHandlerNumpy(model_uri="my_uri"))

    Args:
      model_uri: The URI to where the model is saved.
      session_options: An optional onnxruntime.SessionOptions object used
        when creating the inference session.
      providers: The list of execution providers, in order of preference,
        that onnxruntime should try to use.
      provider_options: Optional per-provider options passed to
        onnxruntime.InferenceSession alongside providers.
      inference_fn: The inference function to use on RunInference calls.
        default=default_numpy_inference_fn
      kwargs: 'env_vars' can be used to set environment variables
        before loading the model.
    """
    self._model_uri = model_uri
    self._session_options = session_options
    self._providers = providers
    self._provider_options = provider_options
    self._model_inference_fn = inference_fn
    self._env_vars = kwargs.get('env_vars', {})

  def load_model(self) -> ort.InferenceSession:
    """Loads and initializes an onnx inference session for processing."""
    # When the path is remote, first load the model proto into memory,
    # then hand the serialized bytes to onnxruntime.
    with FileSystems.open(self._model_uri, "rb") as f:
      model_proto = onnx.load(f)
    model_proto_bytes = model_proto.SerializeToString()
    ort_session = ort.InferenceSession(
        model_proto_bytes,
        sess_options=self._session_options,
        providers=self._providers,
        provider_options=self._provider_options)
    return ort_session

  def run_inference(
      self,
      batch: Sequence[numpy.ndarray],
      inference_session: ort.InferenceSession,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """Runs inferences on a batch of numpy arrays.

    Args:
      batch: A sequence of examples as numpy arrays. They should
        be single examples.
      inference_session: An onnx inference session. It must be able to run
        with inputs that are a sequence of numpy arrays.
      inference_args: Any additional arguments for an inference.

    Returns:
      An Iterable of type PredictionResult.
    """
    predictions = self._model_inference_fn(
        inference_session, batch, inference_args)

    return utils._convert_to_result(batch, predictions)

  def get_num_bytes(self, batch: Sequence[numpy.ndarray]) -> int:
    """
    Returns:
      The number of bytes of data for a batch.
    """
    # nbytes is the total number of bytes consumed by each array's elements.
    return sum(np_array.nbytes for np_array in batch)

  def get_metrics_namespace(self) -> str:
    """
    Returns:
      A namespace for metrics collected by the RunInference transform.
    """
    return 'BeamML_Onnx'
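

# Example pipeline usage (an illustrative sketch, not part of this module;
# the model path and input values are placeholders):
#
#   import apache_beam as beam
#   from apache_beam.ml.inference.base import RunInference
#
#   with beam.Pipeline() as p:
#     _ = (
#         p
#         | beam.Create([numpy.array([1.0, 2.0], dtype=numpy.float32)])
#         | RunInference(
#             OnnxModelHandlerNumpy(model_uri="gs://my-bucket/model.onnx"))
#         | beam.Map(print))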