# Source: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/onnx_inference.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence

import numpy

import onnx
import onnxruntime as ort
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult

__all__ = ['OnnxModelHandlerNumpy']

# Signature of a custom inference function for OnnxModelHandlerNumpy.
# It is called as fn(inference_session, batch, inference_args) and returns
# the raw model predictions. OnnxModelHandlerNumpy.run_inference then pairs
# those predictions with the batch to build PredictionResult objects.
NumpyInferenceFn = Callable[
    [ort.InferenceSession, Sequence[numpy.ndarray], Optional[Dict[str, Any]]],
    Any]


def default_numpy_inference_fn(
    inference_session: ort.InferenceSession,
    batch: Sequence[numpy.ndarray],
    inference_args: Optional[Dict[str, Any]] = None) -> Any:
  """Runs ``inference_session`` on ``batch`` stacked into a single array.

  The examples in ``batch`` are combined with ``numpy.stack(..., axis=0)``.
  They must therefore all have the same shape. The stacked array is fed to
  the session's first declared input.

  Args:
    inference_session: The onnxruntime session to run.
    batch: The examples to run inference on, one numpy array per example.
    inference_args: Optional extra entries merged into the input feed. On a
      name collision they take precedence over the stacked batch.

  Returns:
    The first output returned by ``inference_session.run``.
  """
  ort_inputs = {
      inference_session.get_inputs()[0].name: numpy.stack(batch, axis=0)
  }
  if inference_args:
    # Later keys win, so user-supplied inputs override the stacked batch.
    ort_inputs = {**ort_inputs, **inference_args}
  # Passing None as output_names requests every model output; only the first
  # one is returned.
  return inference_session.run(None, ort_inputs)[0]
class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
                                         PredictionResult,
                                         ort.InferenceSession]):
  # Default execution providers, in priority order. This is an immutable tuple
  # so that no handler instance can mutate the shared default value.
  _DEFAULT_PROVIDERS = ('CUDAExecutionProvider', 'CPUExecutionProvider')

  def __init__(
      self,
      model_uri: str,
      session_options=None,
      providers=_DEFAULT_PROVIDERS,
      provider_options=None,
      *,
      inference_fn: NumpyInferenceFn = default_numpy_inference_fn,
      **kwargs):
    """Implementation of the ModelHandler interface for ONNX
    using numpy arrays as input.

    With the default inference function, all examples in a batch must have
    the same shape, because they are combined with ``numpy.stack``.

    Example Usage::

      pcoll | RunInference(OnnxModelHandlerNumpy(model_uri="my_uri"))

    Args:
      model_uri: The URI to where the model is saved. It is read through
        Beam's ``FileSystems``, so non-local paths can be used.
      session_options: Optional session options, forwarded as
        ``sess_options`` to ``onnxruntime.InferenceSession``.
      providers: Execution providers forwarded to
        ``onnxruntime.InferenceSession``. Defaults to
        ``CUDAExecutionProvider`` followed by ``CPUExecutionProvider``.
      provider_options: Optional provider options, forwarded to
        ``onnxruntime.InferenceSession``.
      inference_fn: The inference function to use on RunInference calls.
        default=default_numpy_inference_fn
      kwargs: 'env_vars' can be used to set environment variables
        before loading the model.
    """
    self._model_uri = model_uri
    self._session_options = session_options
    # Store a private copy so that this handler and the caller never share a
    # mutable list. An explicit None is forwarded to onnxruntime unchanged.
    self._providers = list(providers) if providers is not None else None
    self._provider_options = provider_options
    self._model_inference_fn = inference_fn
    self._env_vars = kwargs.get('env_vars', {})

  def load_model(self) -> ort.InferenceSession:
    """Loads and initializes an onnx inference session for processing.

    Returns:
      An ``onnxruntime.InferenceSession`` built from the serialized model and
      configured with this handler's session options and providers.
    """
    # When the path is remote, the model is first loaded into memory through
    # FileSystems and then handed to onnxruntime as bytes. The `with` block
    # ensures the file handle is closed once the proto has been read.
    with FileSystems.open(self._model_uri, "rb") as f:
      model_proto = onnx.load(f)
    # Use the public protobuf API instead of the private onnx._serialize.
    # Already-serialized bytes are passed through, as _serialize did.
    if isinstance(model_proto, bytes):
      model_proto_bytes = model_proto
    else:
      model_proto_bytes = model_proto.SerializeToString()
    return ort.InferenceSession(
        model_proto_bytes,
        sess_options=self._session_options,
        providers=self._providers,
        provider_options=self._provider_options)

  def run_inference(
      self,
      batch: Sequence[numpy.ndarray],
      inference_session: ort.InferenceSession,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """Runs inferences on a batch of numpy arrays.

    Args:
      batch: A sequence of examples as numpy arrays. They should
        be single examples.
      inference_session: An onnx inference session, passed unchanged to the
        configured inference function.
      inference_args: Any additional arguments for an inference.

    Returns:
      An Iterable of type PredictionResult.
    """
    predictions = self._model_inference_fn(
        inference_session, batch, inference_args)

    return utils._convert_to_result(batch, predictions)

  def get_num_bytes(self, batch: Sequence[numpy.ndarray]) -> int:
    """
    Returns:
      The number of bytes of data for a batch.
    """
    # nbytes is the size of each array's data. The previous `itemsize` only
    # counted a single element per array.
    return sum(np_array.nbytes for np_array in batch)

  def get_metrics_namespace(self) -> str:
    """
    Returns:
       A namespace for metrics collected by the RunInference transform.
    """
    return 'BeamML_Onnx'