github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/gpus/cuda_configure.bzl (about)

# -*- Python -*-
"""Repository rule for CUDA autoconfiguration.

`cuda_configure` depends on the following environment variables:

  * `TF_NEED_CUDA`: Whether to enable building with CUDA.
  * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path.
  * `TF_CUDA_CLANG`: Whether to use clang as a CUDA compiler.
  * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for
    both host and device code compilation if TF_CUDA_CLANG is 1.
  * `TF_DOWNLOAD_CLANG`: Whether to download a recent release of clang
    compiler and use it to build tensorflow. When this option is set
    CLANG_CUDA_COMPILER_PATH is ignored.
  * `TF_CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is
    `/usr/local/cuda,/usr/`.
  * `CUDA_TOOLKIT_PATH` (deprecated): The path to the CUDA toolkit. Default is
    `/usr/local/cuda`.
  * `TF_CUDA_VERSION`: The version of the CUDA toolkit. If this is blank, then
    use the system default.
  * `TF_CUDNN_VERSION`: The version of the cuDNN library.
  * `CUDNN_INSTALL_PATH` (deprecated): The path to the cuDNN library. Default is
    `/usr/local/cuda`.
  * `TF_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default is
    `3.5,5.2`.
  * `PYTHON_BIN_PATH`: The python binary path.
"""
    27  
    28  load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
    29  load(
    30      "@bazel_tools//tools/cpp:lib_cc_configure.bzl",
    31      "escape_string",
    32      "get_env_var",
    33  )
    34  load(
    35      "@bazel_tools//tools/cpp:windows_cc_configure.bzl",
    36      "find_msvc_tool",
    37      "find_vc_path",
    38      "setup_vc_env_vars",
    39  )
    40  
# Names of environment variables consulted during CUDA configuration (most are
# described in the module docstring). Each constant's value is the variable's
# own name.
_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"

# NOTE(review): not referenced in the visible part of this file; presumably
# read elsewhere in the rule — confirm its semantics at the use site.
_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
_CLANG_CUDA_COMPILER_PATH = "CLANG_CUDA_COMPILER_PATH"
_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
_TF_CUDA_VERSION = "TF_CUDA_VERSION"
_TF_CUDNN_VERSION = "TF_CUDNN_VERSION"
_CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH"
_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES"

# NOTE(review): not referenced in the visible part of this file; presumably
# consumed elsewhere in the rule — confirm.
_TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO"
_TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"

# Returned by compute_capabilities() when TF_CUDA_COMPUTE_CAPABILITIES is unset.
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = ["3.5", "5.2"]
    54  
    55  def to_list_of_strings(elements):
    56      """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'.
    57  
    58      This is to be used to put a list of strings into the bzl file templates
    59      so it gets interpreted as list of strings in Starlark.
    60  
    61      Args:
    62        elements: list of string elements
    63  
    64      Returns:
    65        single string of elements wrapped in quotes separated by a comma."""
    66      quoted_strings = ["\"" + element + "\"" for element in elements]
    67      return ", ".join(quoted_strings)
    68  
    69  def verify_build_defines(params):
    70      """Verify all variables that crosstool/BUILD.tpl expects are substituted.
    71  
    72      Args:
    73        params: dict of variables that will be passed to the BUILD.tpl template.
    74      """
    75      missing = []
    76      for param in [
    77          "cxx_builtin_include_directories",
    78          "extra_no_canonical_prefixes_flags",
    79          "host_compiler_path",
    80          "host_compiler_prefix",
    81          "host_compiler_warnings",
    82          "linker_bin_path",
    83          "linker_files",
    84          "msvc_cl_path",
    85          "msvc_env_include",
    86          "msvc_env_lib",
    87          "msvc_env_path",
    88          "msvc_env_tmp",
    89          "msvc_lib_path",
    90          "msvc_link_path",
    91          "msvc_ml_path",
    92          "unfiltered_compile_flags",
    93          "win_linker_files",
    94      ]:
    95          if ("%{" + param + "}") not in params:
    96              missing.append(param)
    97  
    98      if missing:
    99          auto_configure_fail(
   100              "BUILD.tpl template is missing these variables: " +
   101              str(missing) +
   102              ".\nWe only got: " +
   103              str(params) +
   104              ".",
   105          )
   106  
   107  def _get_python_bin(repository_ctx):
   108      """Gets the python bin path."""
   109      python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)
   110      if python_bin != None:
   111          return python_bin
   112      python_bin_name = "python.exe" if _is_windows(repository_ctx) else "python"
   113      python_bin_path = repository_ctx.which(python_bin_name)
   114      if python_bin_path != None:
   115          return str(python_bin_path)
   116      auto_configure_fail(
   117          "Cannot find python in PATH, please make sure " +
   118          "python is installed and add its directory in PATH, or --define " +
   119          "%s='/something/else'.\nPATH=%s" % (
   120              _PYTHON_BIN_PATH,
   121              repository_ctx.os.environ.get("PATH", ""),
   122          ),
   123      )
   124  
   125  def _get_nvcc_tmp_dir_for_windows(repository_ctx):
   126      """Return the Windows tmp directory for nvcc to generate intermediate source files."""
   127      escaped_tmp_dir = escape_string(
   128          get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
   129              "\\",
   130              "\\\\",
   131          ),
   132      )
   133      return escaped_tmp_dir + "\\\\nvcc_inter_files_tmp_dir"
   134  
   135  def _get_nvcc_tmp_dir_for_unix(repository_ctx):
   136      """Return the UNIX tmp directory for nvcc to generate intermediate source files."""
   137      escaped_tmp_dir = escape_string(
   138          get_env_var(repository_ctx, "TMPDIR", "/tmp"),
   139      )
   140      return escaped_tmp_dir + "/nvcc_inter_files_tmp_dir"
   141  
   142  def _get_msvc_compiler(repository_ctx):
   143      vc_path = find_vc_path(repository_ctx)
   144      return find_msvc_tool(repository_ctx, vc_path, "cl.exe").replace("\\", "/")
   145  
   146  def _get_win_cuda_defines(repository_ctx):
   147      """Return CROSSTOOL defines for Windows"""
   148  
   149      # If we are not on Windows, return fake vaules for Windows specific fields.
   150      # This ensures the CROSSTOOL file parser is happy.
   151      if not _is_windows(repository_ctx):
   152          return {
   153              "%{msvc_env_tmp}": "msvc_not_used",
   154              "%{msvc_env_path}": "msvc_not_used",
   155              "%{msvc_env_include}": "msvc_not_used",
   156              "%{msvc_env_lib}": "msvc_not_used",
   157              "%{msvc_cl_path}": "msvc_not_used",
   158              "%{msvc_ml_path}": "msvc_not_used",
   159              "%{msvc_link_path}": "msvc_not_used",
   160              "%{msvc_lib_path}": "msvc_not_used",
   161          }
   162  
   163      vc_path = find_vc_path(repository_ctx)
   164      if not vc_path:
   165          auto_configure_fail(
   166              "Visual C++ build tools not found on your machine." +
   167              "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using",
   168          )
   169          return {}
   170  
   171      env = setup_vc_env_vars(repository_ctx, vc_path)
   172      escaped_paths = escape_string(env["PATH"])
   173      escaped_include_paths = escape_string(env["INCLUDE"])
   174      escaped_lib_paths = escape_string(env["LIB"])
   175      escaped_tmp_dir = escape_string(
   176          get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
   177              "\\",
   178              "\\\\",
   179          ),
   180      )
   181  
   182      msvc_cl_path = _get_python_bin(repository_ctx)
   183      msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace(
   184          "\\",
   185          "/",
   186      )
   187      msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace(
   188          "\\",
   189          "/",
   190      )
   191      msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace(
   192          "\\",
   193          "/",
   194      )
   195  
   196      # nvcc will generate some temporary source files under %{nvcc_tmp_dir}
   197      # The generated files are guaranteed to have unique name, so they can share
   198      # the same tmp directory
   199      escaped_cxx_include_directories = [
   200          _get_nvcc_tmp_dir_for_windows(repository_ctx),
   201      ]
   202      for path in escaped_include_paths.split(";"):
   203          if path:
   204              escaped_cxx_include_directories.append(path)
   205  
   206      return {
   207          "%{msvc_env_tmp}": escaped_tmp_dir,
   208          "%{msvc_env_path}": escaped_paths,
   209          "%{msvc_env_include}": escaped_include_paths,
   210          "%{msvc_env_lib}": escaped_lib_paths,
   211          "%{msvc_cl_path}": msvc_cl_path,
   212          "%{msvc_ml_path}": msvc_ml_path,
   213          "%{msvc_link_path}": msvc_link_path,
   214          "%{msvc_lib_path}": msvc_lib_path,
   215          "%{cxx_builtin_include_directories}": to_list_of_strings(
   216              escaped_cxx_include_directories,
   217          ),
   218      }
   219  
   220  # TODO(dzc): Once these functions have been factored out of Bazel's
   221  # cc_configure.bzl, load them from @bazel_tools instead.
   222  # BEGIN cc_configure common functions.
   223  def find_cc(repository_ctx):
   224      """Find the C++ compiler."""
   225      if _is_windows(repository_ctx):
   226          return _get_msvc_compiler(repository_ctx)
   227  
   228      if _use_cuda_clang(repository_ctx):
   229          target_cc_name = "clang"
   230          cc_path_envvar = _CLANG_CUDA_COMPILER_PATH
   231          if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG):
   232              return "extra_tools/bin/clang"
   233      else:
   234          target_cc_name = "gcc"
   235          cc_path_envvar = _GCC_HOST_COMPILER_PATH
   236      cc_name = target_cc_name
   237  
   238      if cc_path_envvar in repository_ctx.os.environ:
   239          cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
   240          if cc_name_from_env:
   241              cc_name = cc_name_from_env
   242      if cc_name.startswith("/"):
   243          # Absolute path, maybe we should make this supported by our which function.
   244          return cc_name
   245      cc = repository_ctx.which(cc_name)
   246      if cc == None:
   247          fail(("Cannot find {}, either correct your path or set the {}" +
   248                " environment variable").format(target_cc_name, cc_path_envvar))
   249      return cc
   250  
# Text in the compiler's `-v` stderr that precedes the list of system include
# directories; _get_cxx_inc_directories_impl starts parsing after this line.
_INC_DIR_MARKER_BEGIN = "#include <...>"

# On OSX the compiler appends " (framework directory)" to some include lines;
# _cxx_inc_convert strips this suffix.
_OSX_FRAMEWORK_SUFFIX = " (framework directory)"
_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)
   256  
   257  def _cxx_inc_convert(path):
   258      """Convert path returned by cc -E xc++ in a complete path."""
   259      path = path.strip()
   260      if path.endswith(_OSX_FRAMEWORK_SUFFIX):
   261          path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
   262      return path
   263  
   264  def _normalize_include_path(repository_ctx, path):
   265      """Normalizes include paths before writing them to the crosstool.
   266  
   267        If path points inside the 'crosstool' folder of the repository, a relative
   268        path is returned.
   269        If path points outside the 'crosstool' folder, an absolute path is returned.
   270        """
   271      path = str(repository_ctx.path(path))
   272      crosstool_folder = str(repository_ctx.path(".").get_child("crosstool"))
   273  
   274      if path.startswith(crosstool_folder):
   275          # We drop the path to "$REPO/crosstool" and a trailing path separator.
   276          return path[len(crosstool_folder) + 1:]
   277      return path
   278  
   279  def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp):
   280      """Compute the list of default C or C++ include directories."""
   281      if lang_is_cpp:
   282          lang = "c++"
   283      else:
   284          lang = "c"
   285      result = repository_ctx.execute([cc, "-E", "-x" + lang, "-", "-v"])
   286      index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
   287      if index1 == -1:
   288          return []
   289      index1 = result.stderr.find("\n", index1)
   290      if index1 == -1:
   291          return []
   292      index2 = result.stderr.rfind("\n ")
   293      if index2 == -1 or index2 < index1:
   294          return []
   295      index2 = result.stderr.find("\n", index2 + 1)
   296      if index2 == -1:
   297          inc_dirs = result.stderr[index1 + 1:]
   298      else:
   299          inc_dirs = result.stderr[index1 + 1:index2].strip()
   300  
   301      return [
   302          _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
   303          for p in inc_dirs.split("\n")
   304      ]
   305  
   306  def get_cxx_inc_directories(repository_ctx, cc):
   307      """Compute the list of default C and C++ include directories."""
   308  
   309      # For some reason `clang -xc` sometimes returns include paths that are
   310      # different from the ones from `clang -xc++`. (Symlink and a dir)
   311      # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
   312      includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
   313      includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)
   314  
   315      return includes_cpp + [
   316          inc
   317          for inc in includes_c
   318          if inc not in includes_cpp
   319      ]
   320  
   321  def auto_configure_fail(msg):
   322      """Output failure message when cuda configuration fails."""
   323      red = "\033[0;31m"
   324      no_color = "\033[0m"
   325      fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg))
   326  
   327  # END cc_configure common functions (see TODO above).
   328  
   329  def _cuda_include_path(repository_ctx, cuda_config):
   330      """Generates the Starlark string with cuda include directories.
   331  
   332        Args:
   333          repository_ctx: The repository context.
   334          cc: The path to the gcc host compiler.
   335  
   336        Returns:
   337          A list of the gcc host compiler include directories.
   338        """
   339      nvcc_path = repository_ctx.path("%s/bin/nvcc%s" % (
   340          cuda_config.cuda_toolkit_path,
   341          ".exe" if cuda_config.cpu_value == "Windows" else "",
   342      ))
   343      result = repository_ctx.execute([
   344          nvcc_path,
   345          "-v",
   346          "/dev/null",
   347          "-o",
   348          "/dev/null",
   349      ])
   350      target_dir = ""
   351      for one_line in result.stderr.splitlines():
   352          if one_line.startswith("#$ _TARGET_DIR_="):
   353              target_dir = (
   354                  cuda_config.cuda_toolkit_path + "/" + one_line.replace(
   355                      "#$ _TARGET_DIR_=",
   356                      "",
   357                  ) + "/include"
   358              )
   359      inc_entries = []
   360      if target_dir != "":
   361          inc_entries.append(target_dir)
   362      inc_entries.append(cuda_config.cuda_toolkit_path + "/include")
   363      return inc_entries
   364  
   365  def enable_cuda(repository_ctx):
   366      """Returns whether to build with CUDA support."""
   367      return int(repository_ctx.os.environ.get("TF_NEED_CUDA", False))
   368  
   369  def matches_version(environ_version, detected_version):
   370      """Checks whether the user-specified version matches the detected version.
   371  
   372        This function performs a weak matching so that if the user specifies only
   373        the
   374        major or major and minor versions, the versions are still considered
   375        matching
   376        if the version parts match. To illustrate:
   377  
   378            environ_version  detected_version  result
   379            -----------------------------------------
   380            5.1.3            5.1.3             True
   381            5.1              5.1.3             True
   382            5                5.1               True
   383            5.1.3            5.1               False
   384            5.2.3            5.1.3             False
   385  
   386        Args:
   387          environ_version: The version specified by the user via environment
   388            variables.
   389          detected_version: The version autodetected from the CUDA installation on
   390            the system.
   391        Returns: True if user-specified version matches detected version and False
   392          otherwise.
   393      """
   394      environ_version_parts = environ_version.split(".")
   395      detected_version_parts = detected_version.split(".")
   396      if len(detected_version_parts) < len(environ_version_parts):
   397          return False
   398      for i, part in enumerate(detected_version_parts):
   399          if i >= len(environ_version_parts):
   400              break
   401          if part != environ_version_parts[i]:
   402              return False
   403      return True
   404  
# Text that precedes the release number in nvcc's version banner.
# NOTE(review): not referenced in the visible part of this file; confirm at
# the use site.
_NVCC_VERSION_PREFIX = "Cuda compilation tools, release "

# Macro line to search for when reading the cuDNN major version from a header.
# NOTE(review): presumably passed to find_cuda_define; the call site is not
# visible here.
_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"
   408  
   409  def find_cuda_define(repository_ctx, header_dir, header_file, define):
   410      """Returns the value of a #define in a header file.
   411  
   412        Greps through a header file and returns the value of the specified #define.
   413        If the #define is not found, then raise an error.
   414  
   415        Args:
   416          repository_ctx: The repository context.
   417          header_dir: The directory containing the header file.
   418          header_file: The header file name.
   419          define: The #define to search for.
   420  
   421        Returns:
   422          The value of the #define found in the header.
   423        """
   424  
   425      # Confirm location of the header and grep for the line defining the macro.
   426      h_path = repository_ctx.path("%s/%s" % (header_dir, header_file))
   427      if not h_path.exists:
   428          auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path)))
   429      result = repository_ctx.execute(
   430          # Grep one more lines as some #defines are splitted into two lines.
   431          [
   432              "grep",
   433              "--color=never",
   434              "-A1",
   435              "-E",
   436              define,
   437              str(h_path),
   438          ],
   439      )
   440      if result.stderr:
   441          auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr))
   442  
   443      # Parse the version from the line defining the macro.
   444      if result.stdout.find(define) == -1:
   445          auto_configure_fail(
   446              "Cannot find line containing '%s' in %s" % (define, h_path),
   447          )
   448  
   449      # Split results to lines
   450      lines = result.stdout.split("\n")
   451      num_lines = len(lines)
   452      for l in range(num_lines):
   453          line = lines[l]
   454          if define in line:  # Find the line with define
   455              version = line
   456              if l != num_lines - 1 and line[-1] == "\\":  # Add next line, if multiline
   457                  version = version[:-1] + lines[l + 1]
   458              break
   459  
   460      # Remove any comments
   461      version = version.split("//")[0]
   462  
   463      # Remove define name
   464      version = version.replace(define, "").strip()
   465  
   466      # Remove the code after the version number.
   467      version_end = version.find(" ")
   468      if version_end != -1:
   469          if version_end == 0:
   470              auto_configure_fail(
   471                  "Cannot extract the version from line containing '%s' in %s" %
   472                  (define, str(h_path)),
   473              )
   474          version = version[:version_end].strip()
   475      return version
   476  
   477  def compute_capabilities(repository_ctx):
   478      """Returns a list of strings representing cuda compute capabilities."""
   479      if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ:
   480          return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
   481      capabilities_str = repository_ctx.os.environ[_TF_CUDA_COMPUTE_CAPABILITIES]
   482      capabilities = capabilities_str.split(",")
   483      for capability in capabilities:
   484          # Workaround for Skylark's lack of support for regex. This check should
   485          # be equivalent to checking:
   486          #     if re.match("[0-9]+.[0-9]+", capability) == None:
   487          parts = capability.split(".")
   488          if len(parts) != 2 or not parts[0].isdigit() or not parts[1].isdigit():
   489              auto_configure_fail("Invalid compute capability: %s" % capability)
   490      return capabilities
   491  
   492  def get_cpu_value(repository_ctx):
   493      """Returns the name of the host operating system.
   494  
   495        Args:
   496          repository_ctx: The repository context.
   497  
   498        Returns:
   499          A string containing the name of the host operating system.
   500        """
   501      os_name = repository_ctx.os.name.lower()
   502      if os_name.startswith("mac os"):
   503          return "Darwin"
   504      if os_name.find("windows") != -1:
   505          return "Windows"
   506      result = repository_ctx.execute(["uname", "-s"])
   507      return result.stdout.strip()
   508  
   509  def _is_windows(repository_ctx):
   510      """Returns true if the host operating system is windows."""
   511      return repository_ctx.os.name.lower().find("windows") >= 0
   512  
   513  def lib_name(base_name, cpu_value, version = None, static = False):
   514      """Constructs the platform-specific name of a library.
   515  
   516        Args:
   517          base_name: The name of the library, such as "cudart"
   518          cpu_value: The name of the host operating system.
   519          version: The version of the library.
   520          static: True the library is static or False if it is a shared object.
   521  
   522        Returns:
   523          The platform-specific name of the library.
   524        """
   525      version = "" if not version else "." + version
   526      if cpu_value in ("Linux", "FreeBSD"):
   527          if static:
   528              return "lib%s.a" % base_name
   529          return "lib%s.so%s" % (base_name, version)
   530      elif cpu_value == "Windows":
   531          return "%s.lib" % base_name
   532      elif cpu_value == "Darwin":
   533          if static:
   534              return "lib%s.a" % base_name
   535          return "lib%s%s.dylib" % (base_name, version)
   536      else:
   537          auto_configure_fail("Invalid cpu_value: %s" % cpu_value)
   538  
   539  def find_lib(repository_ctx, paths, check_soname = True):
   540      """
   541        Finds a library among a list of potential paths.
   542  
   543        Args:
   544          paths: List of paths to inspect.
   545  
   546        Returns:
   547          Returns the first path in paths that exist.
   548      """
   549      objdump = repository_ctx.which("objdump")
   550      mismatches = []
   551      for path in [repository_ctx.path(path) for path in paths]:
   552          if not path.exists:
   553              continue
   554          if check_soname and objdump != None and not _is_windows(repository_ctx):
   555              output = repository_ctx.execute([objdump, "-p", str(path)]).stdout
   556              output = [line for line in output.splitlines() if "SONAME" in line]
   557              sonames = [line.strip().split(" ")[-1] for line in output]
   558              if not any([soname == path.basename for soname in sonames]):
   559                  mismatches.append(str(path))
   560                  continue
   561          return path
   562      if mismatches:
   563          auto_configure_fail(
   564              "None of the libraries match their SONAME: " + ", ".join(mismatches),
   565          )
   566      auto_configure_fail("No library found under: " + ", ".join(paths))
   567  
   568  def _find_cuda_lib(
   569          lib,
   570          repository_ctx,
   571          cpu_value,
   572          basedir,
   573          version,
   574          static = False):
   575      """Finds the given CUDA or cuDNN library on the system.
   576  
   577        Args:
   578          lib: The name of the library, such as "cudart"
   579          repository_ctx: The repository context.
   580          cpu_value: The name of the host operating system.
   581          basedir: The install directory of CUDA or cuDNN.
   582          version: The version of the library.
   583          static: True if static library, False if shared object.
   584  
   585        Returns:
   586          Returns the path to the library.
   587        """
   588      file_name = lib_name(lib, cpu_value, version, static)
   589      return find_lib(
   590          repository_ctx,
   591          ["%s/%s" % (basedir, file_name)],
   592          check_soname = version and not static,
   593      )
   594  
   595  def _find_libs(repository_ctx, cuda_config):
   596      """Returns the CUDA and cuDNN libraries on the system.
   597  
   598        Args:
   599          repository_ctx: The repository context.
   600          cuda_config: The CUDA config as returned by _get_cuda_config
   601  
   602        Returns:
   603          Map of library names to structs of filename and path.
   604        """
   605      cpu_value = cuda_config.cpu_value
   606      stub_dir = "" if _is_windows(repository_ctx) else "/stubs"
   607      return {
   608          "cuda": _find_cuda_lib(
   609              "cuda",
   610              repository_ctx,
   611              cpu_value,
   612              cuda_config.config["cuda_library_dir"] + stub_dir,
   613              None,
   614          ),
   615          "cudart": _find_cuda_lib(
   616              "cudart",
   617              repository_ctx,
   618              cpu_value,
   619              cuda_config.config["cuda_library_dir"],
   620              cuda_config.cuda_version,
   621          ),
   622          "cudart_static": _find_cuda_lib(
   623              "cudart_static",
   624              repository_ctx,
   625              cpu_value,
   626              cuda_config.config["cuda_library_dir"],
   627              cuda_config.cuda_version,
   628              static = True,
   629          ),
   630          "cublas": _find_cuda_lib(
   631              "cublas",
   632              repository_ctx,
   633              cpu_value,
   634              cuda_config.config["cublas_library_dir"],
   635              cuda_config.cuda_lib_version,
   636          ),
   637          "cusolver": _find_cuda_lib(
   638              "cusolver",
   639              repository_ctx,
   640              cpu_value,
   641              cuda_config.config["cuda_library_dir"],
   642              cuda_config.cuda_lib_version,
   643          ),
   644          "curand": _find_cuda_lib(
   645              "curand",
   646              repository_ctx,
   647              cpu_value,
   648              cuda_config.config["cuda_library_dir"],
   649              cuda_config.cuda_lib_version,
   650          ),
   651          "cufft": _find_cuda_lib(
   652              "cufft",
   653              repository_ctx,
   654              cpu_value,
   655              cuda_config.config["cuda_library_dir"],
   656              cuda_config.cuda_lib_version,
   657          ),
   658          "cudnn": _find_cuda_lib(
   659              "cudnn",
   660              repository_ctx,
   661              cpu_value,
   662              cuda_config.config["cudnn_library_dir"],
   663              cuda_config.cudnn_version,
   664          ),
   665          "cupti": _find_cuda_lib(
   666              "cupti",
   667              repository_ctx,
   668              cpu_value,
   669              cuda_config.config["cupti_library_dir"],
   670              cuda_config.cuda_version,
   671          ),
   672          "cusparse": _find_cuda_lib(
   673              "cusparse",
   674              repository_ctx,
   675              cpu_value,
   676              cuda_config.config["cuda_library_dir"],
   677              cuda_config.cuda_lib_version,
   678          ),
   679      }
   680  
   681  def _cudart_static_linkopt(cpu_value):
   682      """Returns additional platform-specific linkopts for cudart."""
   683      return "" if cpu_value == "Darwin" else "\"-lrt\","
   684  
   685  # TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl,
   686  # and nccl_configure.bzl.
   687  def find_cuda_config(repository_ctx, cuda_libraries):
   688      """Returns CUDA config dictionary from running find_cuda_config.py"""
   689      exec_result = repository_ctx.execute([
   690          _get_python_bin(repository_ctx),
   691          repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py")),
   692      ] + cuda_libraries)
   693      if exec_result.return_code:
   694          auto_configure_fail("Failed to run find_cuda_config.py: %s" % exec_result.stderr)
   695  
   696      # Parse the dict from stdout.
   697      return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()])
   698  
   699  def _get_cuda_config(repository_ctx):
   700      """Detects and returns information about the CUDA installation on the system.
   701  
   702        Args:
   703          repository_ctx: The repository context.
   704  
   705        Returns:
   706          A struct containing the following fields:
   707            cuda_toolkit_path: The CUDA toolkit installation directory.
   708            cudnn_install_basedir: The cuDNN installation directory.
   709            cuda_version: The version of CUDA on the system.
   710            cudnn_version: The version of cuDNN on the system.
   711            compute_capabilities: A list of the system's CUDA compute capabilities.
   712            cpu_value: The name of the host operating system.
   713        """
   714      config = find_cuda_config(repository_ctx, ["cuda", "cudnn"])
   715      cpu_value = get_cpu_value(repository_ctx)
   716      toolkit_path = config["cuda_toolkit_path"]
   717  
   718      is_windows = _is_windows(repository_ctx)
   719      cuda_version = config["cuda_version"].split(".")
   720      cuda_major = cuda_version[0]
   721      cuda_minor = cuda_version[1]
   722  
   723      cuda_version = ("64_%s%s" if is_windows else "%s.%s") % (cuda_major, cuda_minor)
   724      cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"]
   725  
   726      # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
   727      # It changed from 'x.y' to just 'x' in CUDA 10.1.
   728      if (int(cuda_major), int(cuda_minor)) >= (10, 1):
   729          cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major
   730      else:
   731          cuda_lib_version = cuda_version
   732  
   733      return struct(
   734          cuda_toolkit_path = toolkit_path,
   735          cuda_version = cuda_version,
   736          cudnn_version = cudnn_version,
   737          cuda_lib_version = cuda_lib_version,
   738          compute_capabilities = compute_capabilities(repository_ctx),
   739          cpu_value = cpu_value,
   740          config = config,
   741      )
   742  
   743  def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
   744      if not out:
   745          out = tpl.replace(":", "/")
   746      repository_ctx.template(
   747          out,
   748          Label("//third_party/gpus/%s.tpl" % tpl),
   749          substitutions,
   750      )
   751  
   752  def _file(repository_ctx, label):
   753      repository_ctx.template(
   754          label.replace(":", "/"),
   755          Label("//third_party/gpus/%s.tpl" % label),
   756          {},
   757      )
   758  
# Contents written to crosstool/error_gpu_disabled.bzl in the dummy (non-CUDA)
# repository. The error_gpu_disabled() macro calls fail() with an actionable
# message, so loading a BUILD file that invokes it aborts the build.
_DUMMY_CROSSTOOL_BZL_FILE = """
def error_gpu_disabled():
  fail("ERROR: Building with --config=cuda but TensorFlow is not configured " +
       "to build with GPU support. Please re-run ./configure and enter 'Y' " +
       "at the prompt to build with GPU support.")

  native.genrule(
      name = "error_gen_crosstool",
      outs = ["CROSSTOOL"],
      cmd = "echo 'Should not be run.' && exit 1",
  )

  native.filegroup(
      name = "crosstool",
      srcs = [":CROSSTOOL"],
      output_licenses = ["unencumbered"],
  )
"""

# Contents written to crosstool/BUILD in the dummy repository: it loads and
# immediately invokes error_gpu_disabled(), so any reference to the crosstool
# package fails with the message defined above.
_DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")

error_gpu_disabled()
"""
   783  
   784  def _create_dummy_repository(repository_ctx):
   785      cpu_value = get_cpu_value(repository_ctx)
   786  
   787      # Set up BUILD file for cuda/.
   788      _tpl(
   789          repository_ctx,
   790          "cuda:build_defs.bzl",
   791          {
   792              "%{cuda_is_configured}": "False",
   793              "%{cuda_extra_copts}": "[]",
   794          },
   795      )
   796      _tpl(
   797          repository_ctx,
   798          "cuda:BUILD",
   799          {
   800              "%{cuda_driver_lib}": lib_name("cuda", cpu_value),
   801              "%{cudart_static_lib}": lib_name(
   802                  "cudart_static",
   803                  cpu_value,
   804                  static = True,
   805              ),
   806              "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
   807              "%{cudart_lib}": lib_name("cudart", cpu_value),
   808              "%{cublas_lib}": lib_name("cublas", cpu_value),
   809              "%{cusolver_lib}": lib_name("cusolver", cpu_value),
   810              "%{cudnn_lib}": lib_name("cudnn", cpu_value),
   811              "%{cufft_lib}": lib_name("cufft", cpu_value),
   812              "%{curand_lib}": lib_name("curand", cpu_value),
   813              "%{cupti_lib}": lib_name("cupti", cpu_value),
   814              "%{cusparse_lib}": lib_name("cusparse", cpu_value),
   815              "%{copy_rules}": """
   816  filegroup(name="cuda-include")
   817  filegroup(name="cublas-include")
   818  filegroup(name="cudnn-include")
   819  """,
   820          },
   821      )
   822  
   823      # Create dummy files for the CUDA toolkit since they are still required by
   824      # tensorflow/core/platform/default/build_config:cuda.
   825      repository_ctx.file("cuda/cuda/include/cuda.h")
   826      repository_ctx.file("cuda/cuda/include/cublas.h")
   827      repository_ctx.file("cuda/cuda/include/cudnn.h")
   828      repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h")
   829      repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cuda", cpu_value))
   830      repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudart", cpu_value))
   831      repository_ctx.file(
   832          "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value),
   833      )
   834      repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value))
   835      repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value))
   836      repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value))
   837      repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value))
   838      repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cufft", cpu_value))
   839      repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cupti", cpu_value))
   840      repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusparse", cpu_value))
   841  
   842      # Set up cuda_config.h, which is used by
   843      # tensorflow/stream_executor/dso_loader.cc.
   844      _tpl(
   845          repository_ctx,
   846          "cuda:cuda_config.h",
   847          {
   848              "%{cuda_version}": "",
   849              "%{cuda_lib_version}": "",
   850              "%{cudnn_version}": "",
   851              "%{cuda_compute_capabilities}": ",".join([
   852                  "CudaVersion(\"%s\")" % c
   853                  for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES
   854              ]),
   855              "%{cuda_toolkit_path}": "",
   856          },
   857          "cuda/cuda/cuda_config.h",
   858      )
   859  
   860      # If cuda_configure is not configured to build with GPU support, and the user
   861      # attempts to build with --config=cuda, add a dummy build rule to intercept
   862      # this and fail with an actionable error message.
   863      repository_ctx.file(
   864          "crosstool/error_gpu_disabled.bzl",
   865          _DUMMY_CROSSTOOL_BZL_FILE,
   866      )
   867      repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
   868  
   869  def _execute(
   870          repository_ctx,
   871          cmdline,
   872          error_msg = None,
   873          error_details = None,
   874          empty_stdout_fine = False):
   875      """Executes an arbitrary shell command.
   876  
   877        Args:
   878          repository_ctx: the repository_ctx object
   879          cmdline: list of strings, the command to execute
   880          error_msg: string, a summary of the error if the command fails
   881          error_details: string, details about the error or steps to fix it
   882          empty_stdout_fine: bool, if True, an empty stdout result is fine,
   883            otherwise it's an error
   884        Return: the result of repository_ctx.execute(cmdline)
   885      """
   886      result = repository_ctx.execute(cmdline)
   887      if result.stderr or not (empty_stdout_fine or result.stdout):
   888          auto_configure_fail(
   889              "\n".join([
   890                  error_msg.strip() if error_msg else "Repository command failed",
   891                  result.stderr.strip(),
   892                  error_details if error_details else "",
   893              ]),
   894          )
   895      return result
   896  
   897  def _norm_path(path):
   898      """Returns a path with '/' and remove the trailing slash."""
   899      path = path.replace("\\", "/")
   900      if path[-1] == "/":
   901          path = path[:-1]
   902      return path
   903  
   904  def make_copy_files_rule(repository_ctx, name, srcs, outs):
   905      """Returns a rule to copy a set of files."""
   906      cmds = []
   907  
   908      # Copy files.
   909      for src, out in zip(srcs, outs):
   910          cmds.append('cp -f "%s" "$(location %s)"' % (src, out))
   911      outs = [('        "%s",' % out) for out in outs]
   912      return """genrule(
   913      name = "%s",
   914      outs = [
   915  %s
   916      ],
   917      cmd = \"""%s \""",
   918  )""" % (name, "\n".join(outs), " && \\\n".join(cmds))
   919  
   920  def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir):
   921      """Returns a rule to recursively copy a directory."""
   922      src_dir = _norm_path(src_dir)
   923      out_dir = _norm_path(out_dir)
   924      outs = _read_dir(repository_ctx, src_dir)
   925      outs = [('        "%s",' % out.replace(src_dir, out_dir)) for out in outs]
   926  
   927      # '@D' already contains the relative path for a single file, see
   928      # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables
   929      out_dir = "$(@D)/%s" % out_dir if len(outs) > 1 else "$(@D)"
   930      return """genrule(
   931      name = "%s",
   932      outs = [
   933  %s
   934      ],
   935      cmd = \"""cp -rLf "%s/." "%s/" \""",
   936  )""" % (name, "\n".join(outs), src_dir, out_dir)
   937  
   938  def _read_dir(repository_ctx, src_dir):
   939      """Returns a string with all files in a directory.
   940  
   941        Finds all files inside a directory, traversing subfolders and following
   942        symlinks. The returned string contains the full path of all files
   943        separated by line breaks.
   944        """
   945      if _is_windows(repository_ctx):
   946          src_dir = src_dir.replace("/", "\\")
   947          find_result = _execute(
   948              repository_ctx,
   949              ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
   950              empty_stdout_fine = True,
   951          )
   952  
   953          # src_files will be used in genrule.outs where the paths must
   954          # use forward slashes.
   955          result = find_result.stdout.replace("\\", "/")
   956      else:
   957          find_result = _execute(
   958              repository_ctx,
   959              ["find", src_dir, "-follow", "-type", "f"],
   960              empty_stdout_fine = True,
   961          )
   962          result = find_result.stdout
   963      return sorted(result.splitlines())
   964  
   965  def _flag_enabled(repository_ctx, flag_name):
   966      if flag_name in repository_ctx.os.environ:
   967          value = repository_ctx.os.environ[flag_name].strip()
   968          return value == "1"
   969      return False
   970  
   971  def _use_cuda_clang(repository_ctx):
   972      return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
   973  
   974  def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
   975      if _use_cuda_clang(repository_ctx):
   976          capability_flags = [
   977              "--cuda-gpu-arch=sm_" + cap.replace(".", "")
   978              for cap in compute_capabilities
   979          ]
   980      else:
   981          # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc
   982          # TODO(csigg): Make this consistent with cuda clang and pass to crosstool.
   983          capability_flags = []
   984      return str(capability_flags)
   985  
def _create_local_cuda_repository(repository_ctx):
    """Creates the repository containing files set up to build with CUDA.

    Uses the installation detected by _get_cuda_config() to write:
      * cuda/BUILD (cuda/BUILD.windows template on Windows) with genrules that
        copy headers, libraries, binaries and libdevice files from the local
        installation into the repository;
      * cuda/build_defs.bzl and cuda/cuda/cuda_config.h;
      * the crosstool/ package (BUILD, cc_toolchain_config.bzl and, for nvcc
        builds, the compiler wrapper scripts).

    Args:
      repository_ctx: The repository context.
    """
    cuda_config = _get_cuda_config(repository_ctx)

    # Directories reported by find_cuda_config.py (see _get_cuda_config).
    cuda_include_path = cuda_config.config["cuda_include_dir"]
    cublas_include_path = cuda_config.config["cublas_include_dir"]
    cudnn_header_dir = cuda_config.config["cudnn_include_dir"]
    cupti_header_dir = cuda_config.config["cupti_include_dir"]
    nvvm_libdevice_dir = cuda_config.config["nvvm_library_dir"]

    # Create genrule to copy files from the installed CUDA toolkit into execroot.
    copy_rules = [
        make_copy_dir_rule(
            repository_ctx,
            name = "cuda-include",
            src_dir = cuda_include_path,
            out_dir = "cuda/include",
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = "cuda-nvvm",
            src_dir = nvvm_libdevice_dir,
            out_dir = "cuda/nvvm/libdevice",
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = "cuda-extras",
            src_dir = cupti_header_dir,
            out_dir = "cuda/extras/CUPTI/include",
        ),
    ]

    # Only the three cuBLAS headers are copied, not the whole include dir.
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cublas-include",
        srcs = [
            cublas_include_path + "/cublas.h",
            cublas_include_path + "/cublas_v2.h",
            cublas_include_path + "/cublas_api.h",
        ],
        outs = [
            "cublas/include/cublas.h",
            "cublas/include/cublas_v2.h",
            "cublas/include/cublas_api.h",
        ],
    ))

    # Copy every library located by _find_libs into cuda/lib/, keeping each
    # file's basename (the same basenames are substituted into cuda/BUILD below).
    cuda_libs = _find_libs(repository_ctx, cuda_config)
    cuda_lib_srcs = []
    cuda_lib_outs = []
    for path in cuda_libs.values():
        cuda_lib_srcs.append(str(path))
        cuda_lib_outs.append("cuda/lib/" + path.basename)
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cuda-lib",
        srcs = cuda_lib_srcs,
        outs = cuda_lib_outs,
    ))

    copy_rules.append(make_copy_dir_rule(
        repository_ctx,
        name = "cuda-bin",
        src_dir = cuda_config.cuda_toolkit_path + "/bin",
        out_dir = "cuda/bin",
    ))

    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cudnn-include",
        srcs = [cudnn_header_dir + "/cudnn.h"],
        outs = ["cudnn/include/cudnn.h"],
    ))

    # Set up BUILD file for cuda/
    _tpl(
        repository_ctx,
        "cuda:build_defs.bzl",
        {
            "%{cuda_is_configured}": "True",
            "%{cuda_extra_copts}": _compute_cuda_extra_copts(
                repository_ctx,
                cuda_config.compute_capabilities,
            ),
        },
    )
    _tpl(
        repository_ctx,
        "cuda:BUILD.windows" if _is_windows(repository_ctx) else "cuda:BUILD",
        {
            "%{cuda_driver_lib}": cuda_libs["cuda"].basename,
            "%{cudart_static_lib}": cuda_libs["cudart_static"].basename,
            "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value),
            "%{cudart_lib}": cuda_libs["cudart"].basename,
            "%{cublas_lib}": cuda_libs["cublas"].basename,
            "%{cusolver_lib}": cuda_libs["cusolver"].basename,
            "%{cudnn_lib}": cuda_libs["cudnn"].basename,
            "%{cufft_lib}": cuda_libs["cufft"].basename,
            "%{curand_lib}": cuda_libs["curand"].basename,
            "%{cupti_lib}": cuda_libs["cupti"].basename,
            "%{cusparse_lib}": cuda_libs["cusparse"].basename,
            "%{copy_rules}": "\n".join(copy_rules),
        },
        "cuda/BUILD",
    )

    is_cuda_clang = _use_cuda_clang(repository_ctx)

    # Clang is only downloaded when clang is also the CUDA compiler.
    should_download_clang = is_cuda_clang and _flag_enabled(
        repository_ctx,
        _TF_DOWNLOAD_CLANG,
    )
    if should_download_clang:
        download_clang(repository_ctx, "crosstool/extra_tools")

    # Set up crosstool/
    cc = find_cc(repository_ctx)

    # A downloaded clang was placed under crosstool/ above, so prefix the path
    # for probing include dirs. NOTE(review): presumably find_cc returns a path
    # relative to crosstool/ in that case; confirm.
    cc_fullpath = cc if not should_download_clang else "crosstool/" + cc

    host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc_fullpath)

    # Placeholder -> value map for the crosstool:BUILD template.
    cuda_defines = {}

    host_compiler_prefix = "/usr/bin"
    if _GCC_HOST_COMPILER_PREFIX in repository_ctx.os.environ:
        host_compiler_prefix = repository_ctx.os.environ[_GCC_HOST_COMPILER_PREFIX].strip()
    cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix

    # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
    # https://github.com/bazelbuild/bazel/issues/760).
    # However, this stops our custom clang toolchain from picking the provided
    # LLD linker, so we're only adding '-B/usr/bin' when using non-downloaded
    # toolchain.
    # TODO: when bazel stops adding '-B/usr/bin' by default, remove this
    #       flag from the CROSSTOOL completely (see
    #       https://github.com/bazelbuild/bazel/issues/5634)
    if should_download_clang:
        cuda_defines["%{linker_bin_path}"] = ""
    else:
        cuda_defines["%{linker_bin_path}"] = host_compiler_prefix

    cuda_defines["%{extra_no_canonical_prefixes_flags}"] = ""
    cuda_defines["%{unfiltered_compile_flags}"] = ""
    if is_cuda_clang:
        # clang compiles both host and device code directly. The nvcc wrapper
        # scripts are written as empty files because they are not used.
        cuda_defines["%{host_compiler_path}"] = str(cc)
        cuda_defines["%{host_compiler_warnings}"] = """
        # Some parts of the codebase set -Werror and hit this warning, so
        # switch it off for now.
        "-Wno-invalid-partial-specialization"
    """
        cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(host_compiler_includes)
        cuda_defines["%{linker_files}"] = ":empty"
        cuda_defines["%{win_linker_files}"] = ":empty"
        repository_ctx.file(
            "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
            "",
        )
        repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "")
    else:
        # nvcc path: the toolchain's compiler is the generated wrapper script,
        # which is rendered from a template below using wrapper_defines.
        cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
        cuda_defines["%{host_compiler_warnings}"] = ""

        # nvcc has the system include paths built in and will automatically
        # search them; we cannot work around that, so we add the relevant cuda
        # system paths to the allowed compiler specific include paths.
        cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(
            host_compiler_includes + _cuda_include_path(
                repository_ctx,
                cuda_config,
            ) + [cupti_header_dir, cudnn_header_dir],
        )

        # For gcc, do not canonicalize system header paths; some versions of gcc
        # pick the shortest possible path for system includes when creating the
        # .d file - given that includes that are prefixed with "../" multiple
        # time quickly grow longer than the root of the tree, this can lead to
        # bazel's header check failing.
        cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\""

        # Absolute path of nvcc (nvcc.exe on Windows) in the detected bin dir.
        nvcc_path = str(
            repository_ctx.path("%s/nvcc%s" % (
                cuda_config.config["cuda_binary_dir"],
                ".exe" if _is_windows(repository_ctx) else "",
            )),
        )
        cuda_defines["%{linker_files}"] = ":crosstool_wrapper_driver_is_not_gcc"
        cuda_defines["%{win_linker_files}"] = ":windows_msvc_wrapper_files"

        # Shared substitutions for both the Unix and the Windows wrapper.
        wrapper_defines = {
            "%{cpu_compiler}": str(cc),
            "%{cuda_version}": cuda_config.cuda_version,
            "%{nvcc_path}": nvcc_path,
            "%{gcc_host_compiler_path}": str(cc),
            "%{cuda_compute_capabilities}": ", ".join(
                ["\"%s\"" % c for c in cuda_config.compute_capabilities],
            ),
            "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx),
        }
        _tpl(
            repository_ctx,
            "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
            wrapper_defines,
        )
        _tpl(
            repository_ctx,
            "crosstool:windows/msvc_wrapper_for_nvcc.py",
            wrapper_defines,
        )

    cuda_defines.update(_get_win_cuda_defines(repository_ctx))

    verify_build_defines(cuda_defines)

    # Only expand template variables in the BUILD file
    _tpl(repository_ctx, "crosstool:BUILD", cuda_defines)

    # No templating of cc_toolchain_config - use attributes and templatize the
    # BUILD file.
    _file(repository_ctx, "crosstool:cc_toolchain_config.bzl")

    # Set up cuda_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    _tpl(
        repository_ctx,
        "cuda:cuda_config.h",
        {
            "%{cuda_version}": cuda_config.cuda_version,
            "%{cuda_lib_version}": cuda_config.cuda_lib_version,
            "%{cudnn_version}": cuda_config.cudnn_version,
            "%{cuda_compute_capabilities}": ", ".join([
                "CudaVersion(\"%s\")" % c
                for c in cuda_config.compute_capabilities
            ]),
            "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
        },
        "cuda/cuda/cuda_config.h",
    )
  1222  
  1223  def _create_remote_cuda_repository(repository_ctx, remote_config_repo):
  1224      """Creates pointers to a remotely configured repo set up to build with CUDA."""
  1225      _tpl(
  1226          repository_ctx,
  1227          "cuda:build_defs.bzl",
  1228          {
  1229              "%{cuda_is_configured}": "True",
  1230              "%{cuda_extra_copts}": _compute_cuda_extra_copts(
  1231                  repository_ctx,
  1232                  compute_capabilities(repository_ctx),
  1233              ),
  1234          },
  1235      )
  1236      repository_ctx.template(
  1237          "cuda/BUILD",
  1238          Label(remote_config_repo + "/cuda:BUILD"),
  1239          {},
  1240      )
  1241      repository_ctx.template(
  1242          "cuda/build_defs.bzl",
  1243          Label(remote_config_repo + "/cuda:build_defs.bzl"),
  1244          {},
  1245      )
  1246      repository_ctx.template(
  1247          "cuda/cuda/cuda_config.h",
  1248          Label(remote_config_repo + "/cuda:cuda/cuda_config.h"),
  1249          {},
  1250      )
  1251  
  1252  def _cuda_autoconf_impl(repository_ctx):
  1253      """Implementation of the cuda_autoconf repository rule."""
  1254      if not enable_cuda(repository_ctx):
  1255          _create_dummy_repository(repository_ctx)
  1256      elif _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ:
  1257          if (_TF_CUDA_VERSION not in repository_ctx.os.environ or
  1258              _TF_CUDNN_VERSION not in repository_ctx.os.environ):
  1259              auto_configure_fail("%s and %s must also be set if %s is specified" %
  1260                                  (_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO))
  1261          _create_remote_cuda_repository(
  1262              repository_ctx,
  1263              repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO],
  1264          )
  1265      else:
  1266          _create_local_cuda_repository(repository_ctx)
  1267  
# Repository rule that sets up the CUDA configuration repository (e.g. the
# one named local_config_cuda in the usage example below). `environ` lists the
# environment variables this rule depends on. Per the Bazel repository_rule
# documentation, a change to any of them causes the rule to be re-run.
cuda_configure = repository_rule(
    implementation = _cuda_autoconf_impl,
    environ = [
        _GCC_HOST_COMPILER_PATH,
        _GCC_HOST_COMPILER_PREFIX,
        _CLANG_CUDA_COMPILER_PATH,
        "TF_NEED_CUDA",
        "TF_CUDA_CLANG",
        _TF_DOWNLOAD_CLANG,
        _CUDA_TOOLKIT_PATH,
        _CUDNN_INSTALL_PATH,
        _TF_CUDA_VERSION,
        _TF_CUDNN_VERSION,
        _TF_CUDA_COMPUTE_CAPABILITIES,
        _TF_CUDA_CONFIG_REPO,
        "NVVMIR_LIBRARY_DIR",
        _PYTHON_BIN_PATH,
        "TMP",
        "TMPDIR",
        "TF_CUDA_PATHS",
    ],
)

"""Detects and configures the local CUDA toolchain.

Add the following to your WORKSPACE file:

```python
cuda_configure(name = "local_config_cuda")
```

Args:
  name: A unique name for this workspace rule.
"""