github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/inference/tensorrt_inference_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import os
import unittest

import numpy as np
import pytest

import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

# Protect against environments where the TensorRT Python library is not
# available.
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
  import tensorrt as trt
  from apache_beam.ml.inference.base import PredictionResult, RunInference
  from apache_beam.ml.inference.tensorrt_inference import \
      TensorRTEngineHandlerNumPy
except ImportError:
  raise unittest.SkipTest('TensorRT dependencies are not installed')

try:
  from apache_beam.io.gcp.gcsfilesystem import GCSFileSystem
except ImportError:
  GCSFileSystem = None  # type: ignore

LOGGER = trt.Logger(trt.Logger.INFO)

SINGLE_FEATURE_EXAMPLES = [
    np.array(1, dtype=np.float32),
    np.array(5, dtype=np.float32),
    np.array(-3, dtype=np.float32),
    np.array(10.0, dtype=np.float32)
]

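# Expected results for the single-feature test model, which computes
# y = x * 2.0 + 0.5 for each example above.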
SINGLE_FEATURE_PREDICTIONS = [
    PredictionResult(ex, pred) for ex,
    pred in zip(
        SINGLE_FEATURE_EXAMPLES,
        [[np.array([example * 2.0 + 0.5], dtype=np.float32)]
         for example in SINGLE_FEATURE_EXAMPLES])
]

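# Expected results when inference is routed through
# _custom_tensorRT_inference_fn below, which doubles every model output.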
SINGLE_FEATURE_CUSTOM_PREDICTIONS = [
    PredictionResult(ex, pred) for ex,
    pred in zip(
        SINGLE_FEATURE_EXAMPLES,
        [[np.array([(example * 2.0 + 0.5) * 2], dtype=np.float32)]
         for example in SINGLE_FEATURE_EXAMPLES])
]

TWO_FEATURES_EXAMPLES = [
    np.array([1, 5], dtype=np.float32),
    np.array([3, 10], dtype=np.float32),
    np.array([-14, 0], dtype=np.float32),
    np.array([0.5, 0.5], dtype=np.float32)
]

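# Expected results for the two-feature test model, which computes
# y = x[0] * 2.0 + x[1] * 3.0 + 0.5 for each example above.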
TWO_FEATURES_PREDICTIONS = [
    PredictionResult(ex, pred) for ex,
    pred in zip(
        TWO_FEATURES_EXAMPLES,
        [[
            np.array([example[0] * 2.0 + example[1] * 3 + 0.5],
                     dtype=np.float32)
        ] for example in TWO_FEATURES_EXAMPLES])
]


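# Helper to compare PredictionResults whose fields are numpy arrays; plain ==
# on arrays yields an elementwise array rather than a single bool.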
def _compare_prediction_result(a, b):
  return ((a.example == b.example).all() and all(
      np.array_equal(actual, expected) for actual,
      expected in zip(a.inference, b.inference)))


def _assign_or_fail(args):
  """CUDA error checking."""
  from cuda import cuda
  err, ret = args[0], args[1:]
  if isinstance(err, cuda.CUresult):
    if err != cuda.CUresult.CUDA_SUCCESS:
      raise RuntimeError("Cuda Error: {}".format(err))
  else:
    raise RuntimeError("Unknown error type: {}".format(err))
  # Special case so that no unpacking is needed at call-site.
  if len(ret) == 1:
    return ret[0]
  return ret


def _custom_tensorRT_inference_fn(batch, engine, inference_args):
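  # Custom inference_fn for TensorRTEngineHandlerNumPy: performs a TensorRT
  # inference pass (host-to-device copy, execute, device-to-host copy) but
  # doubles every output so the tests can verify this function actually ran.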
  from cuda import cuda
  (
      engine,
      context,
      context_lock,
      inputs,
      outputs,
      gpu_allocations,
      cpu_allocations,
      stream) = engine.get_engine_attrs()

  # Process I/O and execute the network
  with context_lock:
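    # Copy the input batch from host memory to the GPU input allocation.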
    _assign_or_fail(
        cuda.cuMemcpyHtoDAsync(
            inputs[0]['allocation'],
            np.ascontiguousarray(batch),
            inputs[0]['size'],
            stream))
    context.execute_async_v2(gpu_allocations, stream)
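    # Copy each network output back from the GPU to host memory.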
    for output in range(len(cpu_allocations)):
      _assign_or_fail(
          cuda.cuMemcpyDtoHAsync(
              cpu_allocations[output],
              outputs[output]['allocation'],
              outputs[output]['size'],
              stream))
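    # Wait for all queued async work on the stream before reading host buffers.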
    _assign_or_fail(cuda.cuStreamSynchronize(stream))

    return [
        PredictionResult(
            x, [prediction[idx] * 2 for prediction in cpu_allocations]) for idx,
        x in enumerate(batch)
    ]


@pytest.mark.uses_tensorrt
class TensorRTRunInferenceTest(unittest.TestCase):
  @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed')
  def test_inference_single_tensor_feature_onnx(self):
    """
    This tests the ONNX parser and TensorRT engine creation from the parsed
    ONNX network. Single-feature tensors batched into a size of 4 are used as
    input.
    """
    inference_runner = TensorRTEngineHandlerNumPy(
        min_batch_size=4,
        max_batch_size=4,
        onnx_path="gs://apache-beam-ml/models/single_tensor_features_model.onnx"
    )
    network, builder = inference_runner.load_onnx()
    engine = inference_runner.build_engine(network, builder)
    predictions = inference_runner.run_inference(
        SINGLE_FEATURE_EXAMPLES, engine)
    for actual, expected in zip(predictions, SINGLE_FEATURE_PREDICTIONS):
      self.assertEqual(actual, expected)

  @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed')
  def test_inference_multiple_tensor_features_onnx(self):
    """
    This tests the ONNX parser and TensorRT engine creation from the parsed
    ONNX network. Two-feature tensors batched into a size of 4 are used as
    input.
    """
    inference_runner = TensorRTEngineHandlerNumPy(
        min_batch_size=4,
        max_batch_size=4,
        onnx_path=
        'gs://apache-beam-ml/models/multiple_tensor_features_model.onnx')
    network, builder = inference_runner.load_onnx()
    engine = inference_runner.build_engine(network, builder)
    predictions = inference_runner.run_inference(TWO_FEATURES_EXAMPLES, engine)
    for actual, expected in zip(predictions, TWO_FEATURES_PREDICTIONS):
      self.assertTrue(_compare_prediction_result(actual, expected))

  def test_inference_single_tensor_feature(self):
    """
    This tests creating a TensorRT network from scratch. The test replicates
    the same ONNX network above, but natively in TensorRT. After creation, the
    network is used to build a TensorRT engine. Single-feature tensors batched
    into a size of 4 are used as input.
    """
    inference_runner = TensorRTEngineHandlerNumPy(
        min_batch_size=4, max_batch_size=4)
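    # Build y = x * 2.0 + 0.5 natively: a (4, 1) input, a matrix multiply by a
    # constant 2.0 weight, then an elementwise add of a constant 0.5 bias.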
    builder = trt.Builder(LOGGER)
    network = builder.create_network(
        flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    input_tensor = network.add_input(
        name="input", dtype=trt.float32, shape=(4, 1))
    weight_const = network.add_constant(
        (1, 1), trt.Weights((np.ascontiguousarray([2.0], dtype=np.float32))))
    mm = network.add_matrix_multiply(
        input_tensor,
        trt.MatrixOperation.NONE,
        weight_const.get_output(0),
        trt.MatrixOperation.NONE)
    bias_const = network.add_constant(
        (1, 1), trt.Weights((np.ascontiguousarray([0.5], dtype=np.float32))))
    bias_add = network.add_elementwise(
        mm.get_output(0),
        bias_const.get_output(0),
        trt.ElementWiseOperation.SUM)
    bias_add.get_output(0).name = "output"
    network.mark_output(tensor=bias_add.get_output(0))

    engine = inference_runner.build_engine(network, builder)
    predictions = inference_runner.run_inference(
        SINGLE_FEATURE_EXAMPLES, engine)
    for actual, expected in zip(predictions, SINGLE_FEATURE_PREDICTIONS):
      self.assertEqual(actual, expected)

  def test_inference_custom_single_tensor_feature(self):
    """
    This tests creating a TensorRT network from scratch. The test replicates
    the same ONNX network above, but natively in TensorRT. After creation, the
    network is used to build a TensorRT engine. Single-feature tensors batched
    into a size of 4 are used as input, routed through a custom inference
    function.
    """
    inference_runner = TensorRTEngineHandlerNumPy(
        min_batch_size=4,
        max_batch_size=4,
        inference_fn=_custom_tensorRT_inference_fn)
    builder = trt.Builder(LOGGER)
    network = builder.create_network(
        flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    input_tensor = network.add_input(
        name="input", dtype=trt.float32, shape=(4, 1))
    weight_const = network.add_constant(
        (1, 1), trt.Weights((np.ascontiguousarray([2.0], dtype=np.float32))))
    mm = network.add_matrix_multiply(
        input_tensor,
        trt.MatrixOperation.NONE,
        weight_const.get_output(0),
        trt.MatrixOperation.NONE)
    bias_const = network.add_constant(
        (1, 1), trt.Weights((np.ascontiguousarray([0.5], dtype=np.float32))))
    bias_add = network.add_elementwise(
        mm.get_output(0),
        bias_const.get_output(0),
        trt.ElementWiseOperation.SUM)
    bias_add.get_output(0).name = "output"
    network.mark_output(tensor=bias_add.get_output(0))

    engine = inference_runner.build_engine(network, builder)
    predictions = inference_runner.run_inference(
        SINGLE_FEATURE_EXAMPLES, engine)
    for actual, expected in zip(predictions, SINGLE_FEATURE_CUSTOM_PREDICTIONS):
      self.assertEqual(actual, expected)

  def test_inference_multiple_tensor_features(self):
    """
    This tests creating a TensorRT network from scratch. The test replicates
    the same ONNX network above, but natively in TensorRT. After creation, the
    network is used to build a TensorRT engine. Two-feature tensors batched
    into a size of 4 are used as input.
    """
    inference_runner = TensorRTEngineHandlerNumPy(
        min_batch_size=4, max_batch_size=4)
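    # Build y = x[0] * 2.0 + x[1] * 3.0 + 0.5: the (1, 2) weight constant is
    # transposed in the matrix multiply to match the (4, 2) input.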
    builder = trt.Builder(LOGGER)
    network = builder.create_network(
        flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    input_tensor = network.add_input(
        name="input", dtype=trt.float32, shape=(4, 2))
    weight_const = network.add_constant(
        (1, 2), trt.Weights((np.ascontiguousarray([2.0, 3], dtype=np.float32))))
    mm = network.add_matrix_multiply(
        input_tensor,
        trt.MatrixOperation.NONE,
        weight_const.get_output(0),
        trt.MatrixOperation.TRANSPOSE)
    bias_const = network.add_constant(
        (1, 1), trt.Weights((np.ascontiguousarray([0.5], dtype=np.float32))))
    bias_add = network.add_elementwise(
        mm.get_output(0),
        bias_const.get_output(0),
        trt.ElementWiseOperation.SUM)
    bias_add.get_output(0).name = "output"
    network.mark_output(tensor=bias_add.get_output(0))

    engine = inference_runner.build_engine(network, builder)
    predictions = inference_runner.run_inference(TWO_FEATURES_EXAMPLES, engine)
    for actual, expected in zip(predictions, TWO_FEATURES_PREDICTIONS):
      self.assertTrue(_compare_prediction_result(actual, expected))

  @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed')
  def test_inference_single_tensor_feature_built_engine(self):
    """
    This tests an already pre-built TensorRT engine created from an ONNX
    network. To execute this test successfully, the engine used here must have
    been built in the same environment and on the same GPU that runs the test.
    In other words, the engine must be pre-built with the same environment and
    GPU before this test is run; otherwise the behavior may be unpredictable.
    Read more:
    https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#compatibility-serialized-engines
    Single-feature tensors batched into a size of 4 are used as input.
    """
    inference_runner = TensorRTEngineHandlerNumPy(
        min_batch_size=4,
        max_batch_size=4,
        engine_path=
        'gs://apache-beam-ml/models/single_tensor_features_engine.trt')
    engine = inference_runner.load_model()
    predictions = inference_runner.run_inference(
        SINGLE_FEATURE_EXAMPLES, engine)
    for actual, expected in zip(predictions, SINGLE_FEATURE_PREDICTIONS):
      self.assertEqual(actual, expected)

  @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed')
  def test_inference_multiple_tensor_feature_built_engine(self):
    """
    This tests an already pre-built TensorRT engine created from an ONNX
    network. To execute this test successfully, the engine used here must have
    been built in the same environment and on the same GPU that runs the test.
    In other words, the engine must be pre-built with the same environment and
    GPU before this test is run; otherwise the behavior may be unpredictable.
    Read more:
    https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#compatibility-serialized-engines
    Two-feature tensors batched into a size of 4 are used as input.
    """
    inference_runner = TensorRTEngineHandlerNumPy(
        min_batch_size=4,
        max_batch_size=4,
        engine_path=
        'gs://apache-beam-ml/models/multiple_tensor_features_engine.trt')
    engine = inference_runner.load_model()
    predictions = inference_runner.run_inference(TWO_FEATURES_EXAMPLES, engine)
    for actual, expected in zip(predictions, TWO_FEATURES_PREDICTIONS):
      self.assertTrue(_compare_prediction_result(actual, expected))

  def test_num_bytes(self):
    inference_runner = TensorRTEngineHandlerNumPy(
        min_batch_size=1, max_batch_size=1)
    examples = [
        np.array([1, 5], dtype=np.float32),
        np.array([3, 10], dtype=np.float32),
        np.array([-14, 0], dtype=np.float32),
        np.array([0.5, 0.5], dtype=np.float32)
    ]
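    # Per this assertion, the handler is expected to count one float32
    # itemsize (4 bytes) per array in the batch, i.e. 16 bytes here.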
    self.assertEqual((examples[0].itemsize) * 4,
                     inference_runner.get_num_bytes(examples))

  def test_namespace(self):
    inference_runner = TensorRTEngineHandlerNumPy(
        min_batch_size=4, max_batch_size=4)
    self.assertEqual(
        'RunInferenceTensorRT', inference_runner.get_metrics_namespace())


@pytest.mark.uses_tensorrt
class TensorRTRunInferencePipelineTest(unittest.TestCase):
  @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed')
  def test_pipeline_single_tensor_feature_built_engine(self):
    with TestPipeline() as pipeline:
      engine_handler = TensorRTEngineHandlerNumPy(
          min_batch_size=4,
          max_batch_size=4,
          engine_path=
          'gs://apache-beam-ml/models/single_tensor_features_engine.trt')
      pcoll = pipeline | 'start' >> beam.Create(SINGLE_FEATURE_EXAMPLES)
      predictions = pcoll | RunInference(engine_handler)
      assert_that(
          predictions,
          equal_to(
              SINGLE_FEATURE_PREDICTIONS, equals_fn=_compare_prediction_result))

  @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed')
  def test_pipeline_sets_env_vars_correctly(self):
    with TestPipeline() as pipeline:
      engine_handler = TensorRTEngineHandlerNumPy(
          env_vars={'FOO': 'bar'},
          min_batch_size=4,
          max_batch_size=4,
          engine_path=
          'gs://apache-beam-ml/models/single_tensor_features_engine.trt')
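      # env_vars passed to the handler should be exported into the process
      # environment once the pipeline runs.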
      os.environ.pop('FOO', None)
      self.assertNotIn('FOO', os.environ)
      _ = (
          pipeline
          | 'start' >> beam.Create(SINGLE_FEATURE_EXAMPLES)
          | RunInference(engine_handler))
      pipeline.run()
      self.assertIn('FOO', os.environ)
      self.assertEqual(os.environ['FOO'], 'bar')

  @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed')
  def test_pipeline_multiple_tensor_feature_built_engine(self):
    with TestPipeline() as pipeline:
      engine_handler = TensorRTEngineHandlerNumPy(
          min_batch_size=4,
          max_batch_size=4,
          engine_path=
          'gs://apache-beam-ml/models/multiple_tensor_features_engine.trt')
      pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES)
      predictions = pcoll | RunInference(engine_handler)
      assert_that(
          predictions,
          equal_to(
              TWO_FEATURES_PREDICTIONS, equals_fn=_compare_prediction_result))


if __name__ == '__main__':
  unittest.main()