github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/tensorrt_inference.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

from __future__ import annotations

import logging
import threading
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence
from typing import Tuple

import numpy as np

from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult

LOGGER = logging.getLogger("TensorRTEngineHandlerNumPy")
# This try/except block allows users to submit jobs from a machine that has
# neither a GPU nor the GPU-side dependencies (tensorrt, cuda, etc.) installed
# at job submission time.
try:
  import tensorrt as trt
  TRT_LOGGER = trt.Logger(trt.Logger.INFO)
  trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
  LOGGER.info('tensorrt module successfully imported.')
except ModuleNotFoundError:
  TRT_LOGGER = None
  msg = 'tensorrt module was not found. This is ok as long as the specified ' \
        'runner has tensorrt dependencies installed.'
  LOGGER.warning(msg)


def _load_engine(engine_path):
  import tensorrt as trt
  file = FileSystems.open(engine_path, 'rb')
  runtime = trt.Runtime(TRT_LOGGER)
  engine = runtime.deserialize_cuda_engine(file.read())
  assert engine
  return engine


def _load_onnx(onnx_path):
  import tensorrt as trt
  builder = trt.Builder(TRT_LOGGER)
  network = builder.create_network(
      flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
  parser = trt.OnnxParser(network, TRT_LOGGER)
  with FileSystems.open(onnx_path) as f:
    if not parser.parse(f.read()):
      LOGGER.error("Failed to load ONNX file: %s", onnx_path)
      for error in range(parser.num_errors):
        LOGGER.error(parser.get_error(error))
      raise ValueError(f"Failed to load ONNX file: {onnx_path}")
  return network, builder


def _build_engine(network, builder):
  import tensorrt as trt
  config = builder.create_builder_config()
  runtime = trt.Runtime(TRT_LOGGER)
  plan = builder.build_serialized_network(network, config)
  engine = runtime.deserialize_cuda_engine(plan)
  builder.reset()
  return engine


def _assign_or_fail(args):
  """CUDA error checking."""
  from cuda import cuda
  err, ret = args[0], args[1:]
  if isinstance(err, cuda.CUresult):
    if err != cuda.CUresult.CUDA_SUCCESS:
      raise RuntimeError("Cuda Error: {}".format(err))
  else:
    raise RuntimeError("Unknown error type: {}".format(err))
  # Special case so that no unpacking is needed at call-site.
  if len(ret) == 1:
    return ret[0]
  return ret
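

# The function below is an illustrative sketch only (nothing in this module
# calls it). It shows the cuda-python calling convention that _assign_or_fail
# unwraps: every low-level binding returns a tuple whose first element is a
# CUresult status code, followed by any result values.
def _example_assign_or_fail_usage():
  """Illustrative sketch; assumes the cuda-python bindings are installed."""
  from cuda import cuda
  # cuInit returns a bare (CUresult,) tuple: there is nothing to unwrap,
  # but the status code is still checked.
  _assign_or_fail(cuda.cuInit(0))
  # cuDeviceGetCount returns (CUresult, int); the count is unwrapped because
  # exactly one value follows the status code.
  num_devices = _assign_or_fail(cuda.cuDeviceGetCount())
  return num_devices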


class TensorRTEngine:
  def __init__(self, engine: trt.ICudaEngine):
    """Implementation of the TensorRTEngine class which handles
    allocations associated with a TensorRT engine.

    Example Usage::

      TensorRTEngine(engine)

    Args:
      engine: trt.ICudaEngine object that contains the TensorRT engine
    """
    from cuda import cuda
    import tensorrt as trt
    self.engine = engine
    self.context = engine.create_execution_context()
    self.context_lock = threading.RLock()
    self.inputs = []
    self.outputs = []
    self.gpu_allocations = []
    self.cpu_allocations = []

    # TODO(https://github.com/NVIDIA/TensorRT/issues/2557):
    # Clean up when fixed upstream.
    try:
      _ = np.bool  # type: ignore
    except AttributeError:
      # numpy >= 1.24.0
      np.bool = np.bool_  # type: ignore

    # Set up I/O bindings.
    for i in range(self.engine.num_bindings):
      name = self.engine.get_binding_name(i)
      dtype = self.engine.get_binding_dtype(i)
      shape = self.engine.get_binding_shape(i)
      size = trt.volume(shape) * dtype.itemsize
      allocation = _assign_or_fail(cuda.cuMemAlloc(size))
      binding = {
          'index': i,
          'name': name,
          'dtype': np.dtype(trt.nptype(dtype)),
          'shape': list(shape),
          'allocation': allocation,
          'size': size
      }
      self.gpu_allocations.append(allocation)
      if self.engine.binding_is_input(i):
        self.inputs.append(binding)
      else:
        self.outputs.append(binding)

    assert self.context
    assert len(self.inputs) > 0
    assert len(self.outputs) > 0
    assert len(self.gpu_allocations) > 0

    for output in self.outputs:
      self.cpu_allocations.append(np.zeros(output['shape'], output['dtype']))

    # Create CUDA Stream.
    self.stream = _assign_or_fail(cuda.cuStreamCreate(0))

  def get_engine_attrs(self):
    """Returns TensorRT engine attributes."""
    return (
        self.engine,
        self.context,
        self.context_lock,
        self.inputs,
        self.outputs,
        self.gpu_allocations,
        self.cpu_allocations,
        self.stream)


TensorRTInferenceFn = Callable[
    [Sequence[np.ndarray], TensorRTEngine, Optional[Dict[str, Any]]],
    Iterable[PredictionResult]]


def _default_tensorRT_inference_fn(
    batch: Sequence[np.ndarray],
    engine: TensorRTEngine,
    inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionResult]:
  from cuda import cuda
  (
      engine,
      context,
      context_lock,
      inputs,
      outputs,
      gpu_allocations,
      cpu_allocations,
      stream) = engine.get_engine_attrs()

  # Process I/O and execute the network.
  with context_lock:
    # Copy the host batch into the (single) input allocation on the GPU.
    _assign_or_fail(
        cuda.cuMemcpyHtoDAsync(
            inputs[0]['allocation'],
            np.ascontiguousarray(batch),
            inputs[0]['size'],
            stream))
    context.execute_async_v2(gpu_allocations, stream)
    # Copy every output back into its pre-allocated host buffer.
    for output in range(len(cpu_allocations)):
      _assign_or_fail(
          cuda.cuMemcpyDtoHAsync(
              cpu_allocations[output],
              outputs[output]['allocation'],
              outputs[output]['size'],
              stream))
    _assign_or_fail(cuda.cuStreamSynchronize(stream))

    predictions = []
    for idx in range(len(batch)):
      predictions.append([prediction[idx] for prediction in cpu_allocations])

    return utils._convert_to_result(batch, predictions)
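

# The function below is an illustrative sketch only (nothing in this module
# calls it). A custom TensorRTInferenceFn must match the signature above and
# can be passed to TensorRTEngineHandlerNumPy via its inference_fn argument.
# This one simply logs the batch size and delegates to the default
# implementation.
def _example_logging_inference_fn(
    batch: Sequence[np.ndarray],
    engine: TensorRTEngine,
    inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionResult]:
  LOGGER.info(
      'Running TensorRT inference on a batch of %d elements.', len(batch))
  return _default_tensorRT_inference_fn(batch, engine, inference_args)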


class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
                                              PredictionResult,
                                              TensorRTEngine]):
  def __init__(
      self,
      min_batch_size: int,
      max_batch_size: int,
      *,
      inference_fn: TensorRTInferenceFn = _default_tensorRT_inference_fn,
      **kwargs):
    """Implementation of the ModelHandler interface for TensorRT.

    Example Usage::

      pcoll | RunInference(
          TensorRTEngineHandlerNumPy(
              min_batch_size=1,
              max_batch_size=1,
              engine_path="my_uri"))

    **NOTE:** This API and its implementation are under development and
    do not provide backward compatibility guarantees.

    Args:
      min_batch_size: minimum accepted batch size.
      max_batch_size: maximum accepted batch size.
      inference_fn: the inference function to use on RunInference calls.
        Defaults to _default_tensorRT_inference_fn.
      kwargs: Additional arguments such as 'engine_path' and 'onnx_path' are
        currently supported. 'env_vars' can be used to set environment
        variables before loading the model.

    See https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/
    for details.
    """
    self.min_batch_size = min_batch_size
    self.max_batch_size = max_batch_size
    self.inference_fn = inference_fn
    if 'engine_path' in kwargs:
      self.engine_path = kwargs.get('engine_path')
    elif 'onnx_path' in kwargs:
      self.onnx_path = kwargs.get('onnx_path')
    self._env_vars = kwargs.get('env_vars', {})

  def batch_elements_kwargs(self):
    """Sets min_batch_size and max_batch_size of a TensorRT engine."""
    return {
        'min_batch_size': self.min_batch_size,
        'max_batch_size': self.max_batch_size
    }

  def load_model(self) -> TensorRTEngine:
    """Loads and initializes a TensorRT engine for processing."""
    engine = _load_engine(self.engine_path)
    return TensorRTEngine(engine)

  def load_onnx(self) -> Tuple[trt.INetworkDefinition, trt.Builder]:
    """Loads and parses an ONNX model for processing."""
    return _load_onnx(self.onnx_path)

  def build_engine(
      self, network: trt.INetworkDefinition,
      builder: trt.Builder) -> TensorRTEngine:
    """Builds an engine from the parsed/created network."""
    engine = _build_engine(network, builder)
    return TensorRTEngine(engine)

  def run_inference(
      self,
      batch: Sequence[np.ndarray],
      engine: TensorRTEngine,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """
    Runs inference on a batch of Tensors and returns an Iterable of
    TensorRT Predictions.

    Args:
      batch: A sequence of np.ndarrays, or a single np.ndarray that
        represents a concatenation of multiple arrays as a batch.
      engine: A TensorRT engine.
      inference_args: Any additional arguments for an inference
        that are not applicable to TensorRT.

    Returns:
      An Iterable of type PredictionResult.
    """
    return self.inference_fn(batch, engine, inference_args)

  def get_num_bytes(self, batch: Sequence[np.ndarray]) -> int:
    """
    Returns:
      The number of bytes of data for a batch of Tensors.
    """
    # nbytes is the total size of each array; itemsize would only count a
    # single element.
    return sum(np_array.nbytes for np_array in batch)

  def get_metrics_namespace(self) -> str:
    """
    Returns a namespace for metrics collected by the RunInference transform.
    """
    return 'BeamML_TensorRT'
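

# The function below is an illustrative sketch only (it is not invoked by
# this module). It shows how the handler is typically wired into a pipeline
# with RunInference. The engine path 'gs://my-bucket/model.trt' and the
# (3, 224, 224) float32 input shape are hypothetical and must match the
# engine actually being loaded.
def _example_pipeline_usage():
  import apache_beam as beam
  from apache_beam.ml.inference.base import RunInference

  engine_handler = TensorRTEngineHandlerNumPy(
      min_batch_size=1,
      max_batch_size=4,
      engine_path='gs://my-bucket/model.trt')  # hypothetical URI
  with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | beam.Create([np.zeros((3, 224, 224), dtype=np.float32)])
        | RunInference(engine_handler)
        | beam.Map(print))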