github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/tensorrt_inference.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  from __future__ import annotations
    21  
    22  import logging
    23  import threading
    24  from typing import Any
    25  from typing import Callable
    26  from typing import Dict
    27  from typing import Iterable
    28  from typing import Optional
    29  from typing import Sequence
    30  from typing import Tuple
    31  
    32  import numpy as np
    33  
    34  from apache_beam.io.filesystems import FileSystems
    35  from apache_beam.ml.inference import utils
    36  from apache_beam.ml.inference.base import ModelHandler
    37  from apache_beam.ml.inference.base import PredictionResult
    38  
    39  LOGGER = logging.getLogger("TensorRTEngineHandlerNumPy")
    40  # This try/catch block allows users to submit jobs from a machine without
    41  # GPU and other dependencies (tensorrt, cuda, etc.) at job submission time.
    42  try:
    43    import tensorrt as trt
    44    TRT_LOGGER = trt.Logger(trt.Logger.INFO)
    45    trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
    46    LOGGER.info('tensorrt module successfully imported.')
    47  except ModuleNotFoundError:
    48    TRT_LOGGER = None
    49    msg = 'tensorrt module was not found. This is ok as long as the specified ' \
    50      'runner has tensorrt dependencies installed.'
    51    LOGGER.warning(msg)
    52  
    53  
    54  def _load_engine(engine_path):
    55    import tensorrt as trt
    56    file = FileSystems.open(engine_path, 'rb')
    57    runtime = trt.Runtime(TRT_LOGGER)
    58    engine = runtime.deserialize_cuda_engine(file.read())
    59    assert engine
    60    return engine
    61  
    62  
    63  def _load_onnx(onnx_path):
    64    import tensorrt as trt
    65    builder = trt.Builder(TRT_LOGGER)
    66    network = builder.create_network(
    67        flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    68    parser = trt.OnnxParser(network, TRT_LOGGER)
    69    with FileSystems.open(onnx_path) as f:
    70      if not parser.parse(f.read()):
    71        LOGGER.error("Failed to load ONNX file: %s", onnx_path)
    72        for error in range(parser.num_errors):
    73          LOGGER.error(parser.get_error(error))
    74        raise ValueError(f"Failed to load ONNX file: {onnx_path}")
    75    return network, builder
    76  
    77  
    78  def _build_engine(network, builder):
    79    import tensorrt as trt
    80    config = builder.create_builder_config()
    81    runtime = trt.Runtime(TRT_LOGGER)
    82    plan = builder.build_serialized_network(network, config)
    83    engine = runtime.deserialize_cuda_engine(plan)
    84    builder.reset()
    85    return engine
    86  
    87  
    88  def _assign_or_fail(args):
    89    """CUDA error checking."""
    90    from cuda import cuda
    91    err, ret = args[0], args[1:]
    92    if isinstance(err, cuda.CUresult):
    93      if err != cuda.CUresult.CUDA_SUCCESS:
    94        raise RuntimeError("Cuda Error: {}".format(err))
    95    else:
    96      raise RuntimeError("Unknown error type: {}".format(err))
    97    # Special case so that no unpacking is needed at call-site.
    98    if len(ret) == 1:
    99      return ret[0]
   100    return ret
   101  
   102  
   103  class TensorRTEngine:
   104    def __init__(self, engine: trt.ICudaEngine):
   105      """Implementation of the TensorRTEngine class which handles
   106      allocations associated with TensorRT engine.
   107  
   108      Example Usage::
   109  
   110        TensorRTEngine(engine)
   111  
   112      Args:
   113        engine: trt.ICudaEngine object that contains TensorRT engine
   114      """
   115      from cuda import cuda
   116      import tensorrt as trt
   117      self.engine = engine
   118      self.context = engine.create_execution_context()
   119      self.context_lock = threading.RLock()
   120      self.inputs = []
   121      self.outputs = []
   122      self.gpu_allocations = []
   123      self.cpu_allocations = []
   124  
   125      # TODO(https://github.com/NVIDIA/TensorRT/issues/2557):
   126      # Clean up when fixed upstream.
   127      try:
   128        _ = np.bool  # type: ignore
   129      except AttributeError:
   130        # numpy >= 1.24.0
   131        np.bool = np.bool_  # type: ignore
   132  
   133      # Setup I/O bindings.
   134      for i in range(self.engine.num_bindings):
   135        name = self.engine.get_binding_name(i)
   136        dtype = self.engine.get_binding_dtype(i)
   137        shape = self.engine.get_binding_shape(i)
   138        size = trt.volume(shape) * dtype.itemsize
   139        allocation = _assign_or_fail(cuda.cuMemAlloc(size))
   140        binding = {
   141            'index': i,
   142            'name': name,
   143            'dtype': np.dtype(trt.nptype(dtype)),
   144            'shape': list(shape),
   145            'allocation': allocation,
   146            'size': size
   147        }
   148        self.gpu_allocations.append(allocation)
   149        if self.engine.binding_is_input(i):
   150          self.inputs.append(binding)
   151        else:
   152          self.outputs.append(binding)
   153  
   154      assert self.context
   155      assert len(self.inputs) > 0
   156      assert len(self.outputs) > 0
   157      assert len(self.gpu_allocations) > 0
   158  
   159      for output in self.outputs:
   160        self.cpu_allocations.append(np.zeros(output['shape'], output['dtype']))
   161      # Create CUDA Stream.
   162      self.stream = _assign_or_fail(cuda.cuStreamCreate(0))
   163  
   164    def get_engine_attrs(self):
   165      """Returns TensorRT engine attributes."""
   166      return (
   167          self.engine,
   168          self.context,
   169          self.context_lock,
   170          self.inputs,
   171          self.outputs,
   172          self.gpu_allocations,
   173          self.cpu_allocations,
   174          self.stream)
   175  
   176  
# Signature of an inference function accepted by TensorRTEngineHandlerNumPy:
# it is called as fn(batch, engine, inference_args) with the batch of numpy
# arrays, the TensorRTEngine wrapper and the optional inference_args dict,
# and returns the PredictionResults for the batch.
TensorRTInferenceFn = Callable[
    [Sequence[np.ndarray], TensorRTEngine, Optional[Dict[str, Any]]],
    Iterable[PredictionResult]]
   180  
   181  
   182  def _default_tensorRT_inference_fn(
   183      batch: Sequence[np.ndarray],
   184      engine: TensorRTEngine,
   185      inference_args: Optional[Dict[str,
   186                                    Any]] = None) -> Iterable[PredictionResult]:
   187    from cuda import cuda
   188    (
   189        engine,
   190        context,
   191        context_lock,
   192        inputs,
   193        outputs,
   194        gpu_allocations,
   195        cpu_allocations,
   196        stream) = engine.get_engine_attrs()
   197  
   198    # Process I/O and execute the network
   199    with context_lock:
   200      _assign_or_fail(
   201          cuda.cuMemcpyHtoDAsync(
   202              inputs[0]['allocation'],
   203              np.ascontiguousarray(batch),
   204              inputs[0]['size'],
   205              stream))
   206      context.execute_async_v2(gpu_allocations, stream)
   207      for output in range(len(cpu_allocations)):
   208        _assign_or_fail(
   209            cuda.cuMemcpyDtoHAsync(
   210                cpu_allocations[output],
   211                outputs[output]['allocation'],
   212                outputs[output]['size'],
   213                stream))
   214      _assign_or_fail(cuda.cuStreamSynchronize(stream))
   215  
   216      predictions = []
   217      for idx in range(len(batch)):
   218        predictions.append([prediction[idx] for prediction in cpu_allocations])
   219  
   220      return utils._convert_to_result(batch, predictions)
   221  
   222  
   223  class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
   224                                                PredictionResult,
   225                                                TensorRTEngine]):
   226    def __init__(
   227        self,
   228        min_batch_size: int,
   229        max_batch_size: int,
   230        *,
   231        inference_fn: TensorRTInferenceFn = _default_tensorRT_inference_fn,
   232        **kwargs):
   233      """Implementation of the ModelHandler interface for TensorRT.
   234  
   235      Example Usage::
   236  
   237        pcoll | RunInference(
   238            TensorRTEngineHandlerNumPy(
   239              min_batch_size=1,
   240              max_batch_size=1,
   241              engine_path="my_uri"))
   242  
   243      **NOTE:** This API and its implementation are under development and
   244      do not provide backward compatibility guarantees.
   245  
   246      Args:
   247        min_batch_size: minimum accepted batch size.
   248        max_batch_size: maximum accepted batch size.
   249        inference_fn: the inference function to use on RunInference calls.
   250          default: _default_tensorRT_inference_fn
   251        kwargs: Additional arguments like 'engine_path' and 'onnx_path' are
   252          currently supported. 'env_vars' can be used to set environment variables
   253          before loading the model.
   254  
   255      See https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/
   256      for details
   257      """
   258      self.min_batch_size = min_batch_size
   259      self.max_batch_size = max_batch_size
   260      self.inference_fn = inference_fn
   261      if 'engine_path' in kwargs:
   262        self.engine_path = kwargs.get('engine_path')
   263      elif 'onnx_path' in kwargs:
   264        self.onnx_path = kwargs.get('onnx_path')
   265      self._env_vars = kwargs.get('env_vars', {})
   266  
   267    def batch_elements_kwargs(self):
   268      """Sets min_batch_size and max_batch_size of a TensorRT engine."""
   269      return {
   270          'min_batch_size': self.min_batch_size,
   271          'max_batch_size': self.max_batch_size
   272      }
   273  
   274    def load_model(self) -> TensorRTEngine:
   275      """Loads and initializes a TensorRT engine for processing."""
   276      engine = _load_engine(self.engine_path)
   277      return TensorRTEngine(engine)
   278  
   279    def load_onnx(self) -> Tuple[trt.INetworkDefinition, trt.Builder]:
   280      """Loads and parses an onnx model for processing."""
   281      return _load_onnx(self.onnx_path)
   282  
   283    def build_engine(
   284        self, network: trt.INetworkDefinition,
   285        builder: trt.Builder) -> TensorRTEngine:
   286      """Build an engine according to parsed/created network."""
   287      engine = _build_engine(network, builder)
   288      return TensorRTEngine(engine)
   289  
   290    def run_inference(
   291        self,
   292        batch: Sequence[np.ndarray],
   293        engine: TensorRTEngine,
   294        inference_args: Optional[Dict[str, Any]] = None
   295    ) -> Iterable[PredictionResult]:
   296      """
   297      Runs inferences on a batch of Tensors and returns an Iterable of
   298      TensorRT Predictions.
   299  
   300      Args:
   301        batch: A np.ndarray or a np.ndarray that represents a concatenation
   302          of multiple arrays as a batch.
   303        engine: A TensorRT engine.
   304        inference_args: Any additional arguments for an inference
   305          that are not applicable to TensorRT.
   306  
   307      Returns:
   308        An Iterable of type PredictionResult.
   309      """
   310      return self.inference_fn(batch, engine, inference_args)
   311  
   312    def get_num_bytes(self, batch: Sequence[np.ndarray]) -> int:
   313      """
   314      Returns:
   315        The number of bytes of data for a batch of Tensors.
   316      """
   317      return sum((np_array.itemsize for np_array in batch))
   318  
   319    def get_metrics_namespace(self) -> str:
   320      """
   321      Returns a namespace for metrics collected by the RunInference transform.
   322      """
   323      return 'BeamML_TensorRT'