# -*- Python -*-
"""Repository rule for TensorRT configuration.

`tensorrt_configure` depends on the following environment variables:

  * `TF_NEED_TENSORRT`: Parsed as an integer; a nonzero value enables
    TensorRT. Unset or empty means TensorRT is disabled and a dummy
    repository is generated.
  * `TF_TENSORRT_VERSION`: The TensorRT libnvinfer version.
  * `TENSORRT_INSTALL_PATH`: The installation path of the TensorRT library.
  * `TF_TENSORRT_CONFIG_REPO`: Optional label prefix (e.g. `@repo//pkg`) of a
    pre-configured repository. When set, its files are forwarded verbatim
    and no local detection is performed.
  * `TF_CUDA_PATHS`: Declared so that changes re-run this rule; presumably
    consumed by `find_cuda_config` when locating TensorRT (verify there).
"""

load(
    "//third_party/gpus:cuda_configure.bzl",
    "find_cuda_config",
    "get_cpu_value",
    "lib_name",
    "make_copy_files_rule",
)

_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
_TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
_TF_NEED_TENSORRT = "TF_NEED_TENSORRT"

_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin"]
_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]

# TensorRT 6+ splits its public API across additional headers.
_TF_TENSORRT_HEADERS_V6 = [
    "NvInfer.h",
    "NvUtils.h",
    "NvInferPlugin.h",
    "NvInferVersion.h",
    "NvInferRTSafe.h",
    "NvInferRTExt.h",
    "NvInferPluginUtils.h",
]

_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"

def _at_least_version(actual_version, required_version):
    """Returns whether `actual_version` is >= `required_version`.

    Both arguments are dot-separated version strings whose components must
    all be integers (e.g. "6.0.1"). Components are compared numerically and
    lexicographically, so "6.0.1" >= "6" is True.
    """
    actual = [int(v) for v in actual_version.split(".")]
    required = [int(v) for v in required_version.split(".")]
    return actual >= required

def _get_tensorrt_headers(tensorrt_version):
    """Returns the list of header file names to copy for `tensorrt_version`."""
    if _at_least_version(tensorrt_version, "6"):
        return _TF_TENSORRT_HEADERS_V6
    return _TF_TENSORRT_HEADERS

def _tpl(repository_ctx, tpl, substitutions):
    """Expands `//third_party/tensorrt:<tpl>.tpl` into `<tpl>` in the repository."""
    repository_ctx.template(
        tpl,
        Label("//third_party/tensorrt:%s.tpl" % tpl),
        substitutions,
    )

def _create_dummy_repository(repository_ctx):
    """Create a dummy TensorRT repository.

    Generates the same files as a real configuration, but:
      * `if_tensorrt` is mapped to `if_false`;
      * no copy rules are emitted;
      * references to the `:tensorrt_include` / `:tensorrt_lib` targets are
        blanked out of the BUILD template;
      * the version in `tensorrt_config.h` is left empty.
    """
    _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_false"})
    _tpl(repository_ctx, "BUILD", {
        "%{copy_rules}": "",
        "\":tensorrt_include\"": "",
        "\":tensorrt_lib\"": "",
    })
    _tpl(repository_ctx, "tensorrt/include/tensorrt_config.h", {
        "%{tensorrt_version}": "",
    })

def enable_tensorrt(repository_ctx):
    """Returns whether to build with TensorRT support.

    Reads `TF_NEED_TENSORRT` and parses it as an integer; the result is
    truthy when TensorRT is requested. An unset or empty (whitespace-only)
    variable is treated as "disabled" rather than failing on `int("")`.

    Returns:
      The integer value of `TF_NEED_TENSORRT`, or 0 when it is unset/empty.
    """
    value = repository_ctx.os.environ.get(_TF_NEED_TENSORRT, "").strip()
    if not value:
        return 0
    return int(value)

def _tensorrt_configure_impl(repository_ctx):
    """Implementation of the tensorrt_configure repository rule."""
    if _TF_TENSORRT_CONFIG_REPO in repository_ctx.os.environ:
        # Forward to the pre-configured remote repository: copy each file
        # verbatim (no substitutions) from `<remote_config_repo>:<path>`.
        remote_config_repo = repository_ctx.os.environ[_TF_TENSORRT_CONFIG_REPO]
        for path in [
            "BUILD",
            "build_defs.bzl",
            "tensorrt/include/tensorrt_config.h",
            "LICENSE",
        ]:
            repository_ctx.template(
                path,
                Label(remote_config_repo + ":" + path),
                {},
            )
        return

    # Copy license file in non-remote build.
    repository_ctx.template(
        "LICENSE",
        Label("//third_party/tensorrt:LICENSE"),
        {},
    )

    if not enable_tensorrt(repository_ctx):
        _create_dummy_repository(repository_ctx)
        return

    config = find_cuda_config(repository_ctx, ["tensorrt"])
    trt_version = config["tensorrt_version"]
    cpu_value = get_cpu_value(repository_ctx)

    # Copy the library and header files.
    libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
    library_dir = config["tensorrt_library_dir"] + "/"
    headers = _get_tensorrt_headers(trt_version)
    include_dir = config["tensorrt_include_dir"] + "/"
    copy_rules = [
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_lib",
            srcs = [library_dir + library for library in libraries],
            outs = ["tensorrt/lib/" + library for library in libraries],
        ),
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_include",
            srcs = [include_dir + header for header in headers],
            outs = ["tensorrt/include/" + header for header in headers],
        ),
    ]

    # Set up config file.
    _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_true"})

    # Set up BUILD file.
    _tpl(repository_ctx, "BUILD", {
        "%{copy_rules}": "\n".join(copy_rules),
    })

    # Set up tensorrt_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    _tpl(repository_ctx, "tensorrt/include/tensorrt_config.h", {
        "%{tensorrt_version}": trt_version,
    })

tensorrt_configure = repository_rule(
    implementation = _tensorrt_configure_impl,
    # Any change to these variables invalidates and re-runs the rule.
    environ = [
        _TENSORRT_INSTALL_PATH,
        _TF_TENSORRT_VERSION,
        _TF_TENSORRT_CONFIG_REPO,
        _TF_NEED_TENSORRT,
        "TF_CUDA_PATHS",
    ],
)
"""Detects and configures the local TensorRT installation.

Add the following to your WORKSPACE file:

```python
tensorrt_configure(name = "local_config_tensorrt")
```

Args:
  name: A unique name for this workspace rule.
"""