github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/tensorrt/tensorrt_configure.bzl (about)

# -*- Python -*-
"""Repository rule for TensorRT configuration.

`tensorrt_configure` depends on the following environment variables:

  * `TF_NEED_TENSORRT`: Whether to build with TensorRT support.
  * `TF_TENSORRT_VERSION`: The TensorRT libnvinfer version.
  * `TENSORRT_INSTALL_PATH`: The installation path of the TensorRT library.
  * `TF_TENSORRT_CONFIG_REPO`: Label of a pre-configured remote repository
    to forward to instead of probing the local machine.
"""
     9  
    10  load(
    11      "//third_party/gpus:cuda_configure.bzl",
    12      "find_cuda_config",
    13      "get_cpu_value",
    14      "lib_name",
    15      "make_copy_files_rule",
    16  )
    17  
# Names of the environment variables consulted by this rule (also listed in
# the `environ` attribute of `repository_rule` below, so changes re-trigger
# the rule).
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
_TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
_TF_NEED_TENSORRT = "TF_NEED_TENSORRT"

# TensorRT shared libraries copied into the generated repository.
_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin"]

# Header files to copy for TensorRT releases before version 6.
_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]

# TensorRT 6+ ships additional headers; see _get_tensorrt_headers().
_TF_TENSORRT_HEADERS_V6 = [
    "NvInfer.h",
    "NvUtils.h",
    "NvInferPlugin.h",
    "NvInferVersion.h",
    "NvInferRTSafe.h",
    "NvInferRTExt.h",
    "NvInferPluginUtils.h",
]

# Version-define markers from NvInfer*.h.
# NOTE(review): these three constants are not referenced anywhere in this
# file — presumably leftovers from an older header-parsing code path; confirm
# before removing.
_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"
    38  
    39  def _at_least_version(actual_version, required_version):
    40      actual = [int(v) for v in actual_version.split(".")]
    41      required = [int(v) for v in required_version.split(".")]
    42      return actual >= required
    43  
    44  def _get_tensorrt_headers(tensorrt_version):
    45      if _at_least_version(tensorrt_version, "6"):
    46          return _TF_TENSORRT_HEADERS_V6
    47      return _TF_TENSORRT_HEADERS
    48  
    49  def _tpl(repository_ctx, tpl, substitutions):
    50      repository_ctx.template(
    51          tpl,
    52          Label("//third_party/tensorrt:%s.tpl" % tpl),
    53          substitutions,
    54      )
    55  
    56  def _create_dummy_repository(repository_ctx):
    57      """Create a dummy TensorRT repository."""
    58      _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_false"})
    59      _tpl(repository_ctx, "BUILD", {
    60          "%{copy_rules}": "",
    61          "\":tensorrt_include\"": "",
    62          "\":tensorrt_lib\"": "",
    63      })
    64      _tpl(repository_ctx, "tensorrt/include/tensorrt_config.h", {
    65          "%{tensorrt_version}": "",
    66      })
    67  
    68  def enable_tensorrt(repository_ctx):
    69      """Returns whether to build with TensorRT support."""
    70      return int(repository_ctx.os.environ.get(_TF_NEED_TENSORRT, False))
    71  
    72  def _tensorrt_configure_impl(repository_ctx):
    73      """Implementation of the tensorrt_configure repository rule."""
    74      if _TF_TENSORRT_CONFIG_REPO in repository_ctx.os.environ:
    75          # Forward to the pre-configured remote repository.
    76          remote_config_repo = repository_ctx.os.environ[_TF_TENSORRT_CONFIG_REPO]
    77          repository_ctx.template("BUILD", Label(remote_config_repo + ":BUILD"), {})
    78          repository_ctx.template(
    79              "build_defs.bzl",
    80              Label(remote_config_repo + ":build_defs.bzl"),
    81              {},
    82          )
    83          repository_ctx.template(
    84              "tensorrt/include/tensorrt_config.h",
    85              Label(remote_config_repo + ":tensorrt/include/tensorrt_config.h"),
    86              {},
    87          )
    88          repository_ctx.template(
    89              "LICENSE",
    90              Label(remote_config_repo + ":LICENSE"),
    91              {},
    92          )
    93          return
    94  
    95      # Copy license file in non-remote build.
    96      repository_ctx.template(
    97          "LICENSE",
    98          Label("//third_party/tensorrt:LICENSE"),
    99          {},
   100      )
   101  
   102      if not enable_tensorrt(repository_ctx):
   103          _create_dummy_repository(repository_ctx)
   104          return
   105  
   106      config = find_cuda_config(repository_ctx, ["tensorrt"])
   107      trt_version = config["tensorrt_version"]
   108      cpu_value = get_cpu_value(repository_ctx)
   109  
   110      # Copy the library and header files.
   111      libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
   112      library_dir = config["tensorrt_library_dir"] + "/"
   113      headers = _get_tensorrt_headers(trt_version)
   114      include_dir = config["tensorrt_include_dir"] + "/"
   115      copy_rules = [
   116          make_copy_files_rule(
   117              repository_ctx,
   118              name = "tensorrt_lib",
   119              srcs = [library_dir + library for library in libraries],
   120              outs = ["tensorrt/lib/" + library for library in libraries],
   121          ),
   122          make_copy_files_rule(
   123              repository_ctx,
   124              name = "tensorrt_include",
   125              srcs = [include_dir + header for header in headers],
   126              outs = ["tensorrt/include/" + header for header in headers],
   127          ),
   128      ]
   129  
   130      # Set up config file.
   131      _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_true"})
   132  
   133      # Set up BUILD file.
   134      _tpl(repository_ctx, "BUILD", {
   135          "%{copy_rules}": "\n".join(copy_rules),
   136      })
   137  
   138      # Set up tensorrt_config.h, which is used by
   139      # tensorflow/stream_executor/dso_loader.cc.
   140      _tpl(repository_ctx, "tensorrt/include/tensorrt_config.h", {
   141          "%{tensorrt_version}": trt_version,
   142      })
   143  
# Changes to any of these environment variables invalidate the repository
# and re-run the rule.
tensorrt_configure = repository_rule(
    implementation = _tensorrt_configure_impl,
    environ = [
        _TENSORRT_INSTALL_PATH,
        _TF_TENSORRT_VERSION,
        _TF_TENSORRT_CONFIG_REPO,
        _TF_NEED_TENSORRT,
        "TF_CUDA_PATHS",
    ],
)
"""Detects and configures the local TensorRT toolchain.

Add the following to your WORKSPACE file:

```python
tensorrt_configure(name = "local_config_tensorrt")
```

Args:
  name: A unique name for this workspace rule.
"""