# Listing origin: github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/gpus/cuda_configure.bzl
# -*- Python -*-
"""Repository rule for CUDA autoconfiguration.

`cuda_configure` depends on the following environment variables:

  * `TF_NEED_CUDA`: Whether to enable building with CUDA.
  * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path
  * `TF_CUDA_CLANG`: Whether to use clang as a cuda compiler.
  * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for
    both host and device code compilation if TF_CUDA_CLANG is 1.
  * `TF_DOWNLOAD_CLANG`: Whether to download a recent release of clang
    compiler and use it to build tensorflow. When this option is set
    CLANG_CUDA_COMPILER_PATH is ignored.
  * `TF_CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is
    `/usr/local/cuda,usr/`.
  * `CUDA_TOOLKIT_PATH` (deprecated): The path to the CUDA toolkit. Default is
    `/usr/local/cuda`.
  * `TF_CUDA_VERSION`: The version of the CUDA toolkit. If this is blank, then
    use the system default.
  * `TF_CUDNN_VERSION`: The version of the cuDNN library.
  * `CUDNN_INSTALL_PATH` (deprecated): The path to the cuDNN library. Default is
    `/usr/local/cuda`.
  * `TF_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default is
    `3.5,5.2`.
  * `PYTHON_BIN_PATH`: The python binary path
"""

load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
load(
    "@bazel_tools//tools/cpp:lib_cc_configure.bzl",
    "escape_string",
    "get_env_var",
)
load(
    "@bazel_tools//tools/cpp:windows_cc_configure.bzl",
    "find_msvc_tool",
    "find_vc_path",
    "setup_vc_env_vars",
)

# Names of the environment variables consulted by this repository rule.
_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
_CLANG_CUDA_COMPILER_PATH = "CLANG_CUDA_COMPILER_PATH"
_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
_TF_CUDA_VERSION = "TF_CUDA_VERSION"
_TF_CUDNN_VERSION = "TF_CUDNN_VERSION"
_CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH"
_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES"
_TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO"
_TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"

# Used when TF_CUDA_COMPUTE_CAPABILITIES is not set in the environment.
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = ["3.5", "5.2"]

def to_list_of_strings(elements):
    """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'.

    This is to be used to put a list of strings into the bzl file templates
    so it gets interpreted as list of strings in Starlark.

    Args:
      elements: list of string elements

    Returns:
      single string of elements wrapped in quotes separated by a comma."""
    quoted_strings = ["\"" + element + "\"" for element in elements]
    return ", ".join(quoted_strings)

def verify_build_defines(params):
    """Verify all variables that crosstool/BUILD.tpl expects are substituted.

    Fails the configuration (via auto_configure_fail) listing every expected
    "%{name}" key that is absent from params.

    Args:
      params: dict of variables that will be passed to the BUILD.tpl template.
    """
    missing = []
    for param in [
        "cxx_builtin_include_directories",
        "extra_no_canonical_prefixes_flags",
        "host_compiler_path",
        "host_compiler_prefix",
        "host_compiler_warnings",
        "linker_bin_path",
        "linker_files",
        "msvc_cl_path",
        "msvc_env_include",
        "msvc_env_lib",
        "msvc_env_path",
        "msvc_env_tmp",
        "msvc_lib_path",
        "msvc_link_path",
        "msvc_ml_path",
        "unfiltered_compile_flags",
        "win_linker_files",
    ]:
        # Template keys are stored in their "%{name}" substitution form.
        if ("%{" + param + "}") not in params:
            missing.append(param)

    if missing:
        auto_configure_fail(
            "BUILD.tpl template is missing these variables: " +
            str(missing) +
            ".\nWe only got: " +
            str(params) +
            ".",
        )

def _get_python_bin(repository_ctx):
    """Gets the python bin path.

    The PYTHON_BIN_PATH environment variable wins when set; otherwise the
    interpreter is looked up on PATH. Fails the configuration when neither
    yields a result.
    """
    python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)
    if python_bin != None:
        return python_bin
    python_bin_name = "python.exe" if _is_windows(repository_ctx) else "python"
    python_bin_path = repository_ctx.which(python_bin_name)
    if python_bin_path != None:
        return str(python_bin_path)
    auto_configure_fail(
        "Cannot find python in PATH, please make sure " +
        "python is installed and add its directory in PATH, or --define " +
        "%s='/something/else'.\nPATH=%s" % (
            _PYTHON_BIN_PATH,
            repository_ctx.os.environ.get("PATH", ""),
        ),
    )

def _get_nvcc_tmp_dir_for_windows(repository_ctx):
    """Return the Windows tmp directory for nvcc to generate intermediate source files."""
    escaped_tmp_dir = escape_string(
        get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
            "\\",
            "\\\\",
        ),
    )
    return escaped_tmp_dir + "\\\\nvcc_inter_files_tmp_dir"

def _get_nvcc_tmp_dir_for_unix(repository_ctx):
    """Return the UNIX tmp directory for nvcc to generate intermediate source files."""
    escaped_tmp_dir = escape_string(
        get_env_var(repository_ctx, "TMPDIR", "/tmp"),
    )
    return escaped_tmp_dir + "/nvcc_inter_files_tmp_dir"

def _get_msvc_compiler(repository_ctx):
    """Returns the path of cl.exe, written with forward slashes."""
    vc_path = find_vc_path(repository_ctx)
    return find_msvc_tool(repository_ctx, vc_path, "cl.exe").replace("\\", "/")

def _get_win_cuda_defines(repository_ctx):
    """Return CROSSTOOL defines for Windows

    Args:
      repository_ctx: The repository context.

    Returns:
      A dict of "%{msvc_*}" template substitutions. On non-Windows hosts every
      value is the placeholder "msvc_not_used"; on Windows the dict also
      carries "%{cxx_builtin_include_directories}".
    """

    # If we are not on Windows, return fake values for Windows specific fields.
    # This ensures the CROSSTOOL file parser is happy.
    if not _is_windows(repository_ctx):
        return {
            "%{msvc_env_tmp}": "msvc_not_used",
            "%{msvc_env_path}": "msvc_not_used",
            "%{msvc_env_include}": "msvc_not_used",
            "%{msvc_env_lib}": "msvc_not_used",
            "%{msvc_cl_path}": "msvc_not_used",
            "%{msvc_ml_path}": "msvc_not_used",
            "%{msvc_link_path}": "msvc_not_used",
            "%{msvc_lib_path}": "msvc_not_used",
        }

    vc_path = find_vc_path(repository_ctx)
    if not vc_path:
        auto_configure_fail(
            "Visual C++ build tools not found on your machine. " +
            "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using",
        )

        # Not reached when auto_configure_fail aborts; kept as a defensive
        # fallback.
        return {}

    env = setup_vc_env_vars(repository_ctx, vc_path)
    escaped_paths = escape_string(env["PATH"])
    escaped_include_paths = escape_string(env["INCLUDE"])
    escaped_lib_paths = escape_string(env["LIB"])
    escaped_tmp_dir = escape_string(
        get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
            "\\",
            "\\\\",
        ),
    )

    # NOTE(review): the "cl" path is the Python interpreter rather than cl.exe,
    # presumably because the crosstool runs a Python wrapper script - confirm
    # against the crosstool templates.
    msvc_cl_path = _get_python_bin(repository_ctx)
    msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace(
        "\\",
        "/",
    )
    msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace(
        "\\",
        "/",
    )
    msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace(
        "\\",
        "/",
    )

    # nvcc will generate some temporary source files under %{nvcc_tmp_dir}
    # The generated files are guaranteed to have unique name, so they can share
    # the same tmp directory
    escaped_cxx_include_directories = [
        _get_nvcc_tmp_dir_for_windows(repository_ctx),
    ]
    for path in escaped_include_paths.split(";"):
        if path:
            escaped_cxx_include_directories.append(path)

    return {
        "%{msvc_env_tmp}": escaped_tmp_dir,
        "%{msvc_env_path}": escaped_paths,
        "%{msvc_env_include}": escaped_include_paths,
        "%{msvc_env_lib}": escaped_lib_paths,
        "%{msvc_cl_path}": msvc_cl_path,
        "%{msvc_ml_path}": msvc_ml_path,
        "%{msvc_link_path}": msvc_link_path,
        "%{msvc_lib_path}": msvc_lib_path,
        "%{cxx_builtin_include_directories}": to_list_of_strings(
            escaped_cxx_include_directories,
        ),
    }

# TODO(dzc): Once these functions have been factored out of Bazel's
# cc_configure.bzl, load them from @bazel_tools instead.
# BEGIN cc_configure common functions.
def find_cc(repository_ctx):
    """Find the C++ compiler.

    On Windows this is the MSVC compiler. Otherwise it is clang (when
    TF_CUDA_CLANG is enabled) or gcc, optionally overridden through the
    matching *_COMPILER_PATH environment variable. Fails when the compiler
    cannot be located on PATH.
    """
    if _is_windows(repository_ctx):
        return _get_msvc_compiler(repository_ctx)

    if _use_cuda_clang(repository_ctx):
        target_cc_name = "clang"
        cc_path_envvar = _CLANG_CUDA_COMPILER_PATH
        if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG):
            return "extra_tools/bin/clang"
    else:
        target_cc_name = "gcc"
        cc_path_envvar = _GCC_HOST_COMPILER_PATH
    cc_name = target_cc_name

    if cc_path_envvar in repository_ctx.os.environ:
        cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
        if cc_name_from_env:
            cc_name = cc_name_from_env
    if cc_name.startswith("/"):
        # Absolute path, maybe we should make this supported by our which function.
        return cc_name
    cc = repository_ctx.which(cc_name)
    if cc == None:
        fail(("Cannot find {}, either correct your path or set the {}" +
              " environment variable").format(target_cc_name, cc_path_envvar))
    return cc

_INC_DIR_MARKER_BEGIN = "#include <...>"

# OSX add " (framework directory)" at the end of line, strip it.
_OSX_FRAMEWORK_SUFFIX = " (framework directory)"
_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)

def _cxx_inc_convert(path):
    """Convert path returned by cc -E xc++ in a complete path."""
    path = path.strip()
    if path.endswith(_OSX_FRAMEWORK_SUFFIX):
        path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
    return path

def _normalize_include_path(repository_ctx, path):
    """Normalizes include paths before writing them to the crosstool.

    If path points inside the 'crosstool' folder of the repository, a relative
    path is returned.
    If path points outside the 'crosstool' folder, an absolute path is returned.
    """
    path = str(repository_ctx.path(path))
    crosstool_folder = str(repository_ctx.path(".").get_child("crosstool"))

    if path.startswith(crosstool_folder):
        remainder = path[len(crosstool_folder):]

        # Require the prefix match to end on a path-component boundary so that
        # a sibling such as "$REPO/crosstool_foo" is not mistaken for a child
        # of "$REPO/crosstool".
        if not remainder or remainder[0] in ("/", "\\"):
            # We drop the path to "$REPO/crosstool" and a trailing path separator.
            return remainder[1:]
    return path

def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp):
    """Compute the list of default C or C++ include directories.

    Runs `cc -E -x<lang> - -v` and extracts the directory list that follows the
    "#include <...>" marker in the compiler's stderr. Returns an empty list
    when the expected markers are not found.
    """
    if lang_is_cpp:
        lang = "c++"
    else:
        lang = "c"
    result = repository_ctx.execute([cc, "-E", "-x" + lang, "-", "-v"])
    index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
    if index1 == -1:
        return []

    # Skip to the end of the marker line; the directories start on the next one.
    index1 = result.stderr.find("\n", index1)
    if index1 == -1:
        return []

    # The last line that starts with a space is taken as the last directory.
    index2 = result.stderr.rfind("\n ")
    if index2 == -1 or index2 < index1:
        return []
    index2 = result.stderr.find("\n", index2 + 1)
    if index2 == -1:
        inc_dirs = result.stderr[index1 + 1:]
    else:
        inc_dirs = result.stderr[index1 + 1:index2].strip()

    return [
        _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
        for p in inc_dirs.split("\n")
    ]

def get_cxx_inc_directories(repository_ctx, cc):
    """Compute the list of default C and C++ include directories."""

    # For some reason `clang -xc` sometimes returns include paths that are
    # different from the ones from `clang -xc++`. (Symlink and a dir)
    # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
    includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
    includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)

    return includes_cpp + [
        inc
        for inc in includes_c
        if inc not in includes_cpp
    ]

def auto_configure_fail(msg):
    """Output failure message when cuda configuration fails."""
    red = "\033[0;31m"
    no_color = "\033[0m"
    fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg))

# END cc_configure common functions (see TODO above).

def _cuda_include_path(repository_ctx, cuda_config):
    """Generates the list of cuda include directories.

    Args:
      repository_ctx: The repository context.
      cuda_config: The CUDA config; its cuda_toolkit_path and cpu_value fields
        are read here.

    Returns:
      A list of include directories: the "<toolkit>/<_TARGET_DIR_>/include"
      directory when nvcc reports a _TARGET_DIR_, followed by
      "<toolkit>/include".
    """
    nvcc_path = repository_ctx.path("%s/bin/nvcc%s" % (
        cuda_config.cuda_toolkit_path,
        ".exe" if cuda_config.cpu_value == "Windows" else "",
    ))
    result = repository_ctx.execute([
        nvcc_path,
        "-v",
        "/dev/null",
        "-o",
        "/dev/null",
    ])
    target_dir = ""
    for one_line in result.stderr.splitlines():
        if one_line.startswith("#$ _TARGET_DIR_="):
            target_dir = (
                cuda_config.cuda_toolkit_path + "/" + one_line.replace(
                    "#$ _TARGET_DIR_=",
                    "",
                ) + "/include"
            )
    inc_entries = []
    if target_dir != "":
        inc_entries.append(target_dir)
    inc_entries.append(cuda_config.cuda_toolkit_path + "/include")
    return inc_entries

def enable_cuda(repository_ctx):
    """Returns whether to build with CUDA support.

    The result is the TF_NEED_CUDA environment value converted with int(), or
    0 when the variable is not set.
    """
    return int(repository_ctx.os.environ.get("TF_NEED_CUDA", False))

def matches_version(environ_version, detected_version):
    """Checks whether the user-specified version matches the detected version.

    This function performs a weak matching so that if the user specifies only
    the major or major and minor versions, the versions are still considered
    matching if the version parts match. To illustrate:

        environ_version   detected_version      result
        -----------------------------------------
        5.1.3             5.1.3                 True
        5.1               5.1.3                 True
        5                 5.1                   True
        5.1.3             5.1                   False
        5.2.3             5.1.3                 False

    Args:
      environ_version: The version specified by the user via environment
        variables.
      detected_version: The version autodetected from the CUDA installation on
        the system.

    Returns:
      True if user-specified version matches detected version and False
      otherwise.
    """
    environ_version_parts = environ_version.split(".")
    detected_version_parts = detected_version.split(".")
    if len(detected_version_parts) < len(environ_version_parts):
        return False

    # zip stops at the shorter list, i.e. only the parts the user specified
    # are compared.
    for environ_part, detected_part in zip(
        environ_version_parts,
        detected_version_parts,
    ):
        if detected_part != environ_part:
            return False
    return True

_NVCC_VERSION_PREFIX = "Cuda compilation tools, release "

_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"

def find_cuda_define(repository_ctx, header_dir, header_file, define):
    """Returns the value of a #define in a header file.

    Greps through a header file and returns the value of the specified #define.
    If the #define is not found, then raise an error.

    Args:
      repository_ctx: The repository context.
      header_dir: The directory containing the header file.
      header_file: The header file name.
      define: The #define to search for.

    Returns:
      The value of the #define found in the header.
    """

    # Confirm location of the header and grep for the line defining the macro.
    h_path = repository_ctx.path("%s/%s" % (header_dir, header_file))
    if not h_path.exists:
        auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path)))
    result = repository_ctx.execute(
        # Grep one more line as some #defines are split into two lines.
        [
            "grep",
            "--color=never",
            "-A1",
            "-E",
            define,
            str(h_path),
        ],
    )
    if result.stderr:
        auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr))

    # Parse the version from the line defining the macro.
    if result.stdout.find(define) == -1:
        auto_configure_fail(
            "Cannot find line containing '%s' in %s" % (define, h_path),
        )

    # Split results to lines
    lines = result.stdout.split("\n")
    num_lines = len(lines)
    for index, line in enumerate(lines):
        if define in line:  # Find the line with define
            version = line

            # A trailing backslash continues the #define on the next line.
            if index != num_lines - 1 and line.endswith("\\"):
                version = version[:-1] + lines[index + 1]
            break

    # Remove any comments
    version = version.split("//")[0]

    # Remove define name
    version = version.replace(define, "").strip()

    # Remove the code after the version number.
    version_end = version.find(" ")
    if version_end != -1:
        if version_end == 0:
            auto_configure_fail(
                "Cannot extract the version from line containing '%s' in %s" %
                (define, str(h_path)),
            )
        version = version[:version_end].strip()
    return version

def compute_capabilities(repository_ctx):
    """Returns a list of strings representing cuda compute capabilities.

    Reads the comma-separated TF_CUDA_COMPUTE_CAPABILITIES environment
    variable, falling back to the defaults when it is not set. Fails when an
    entry is not of the form "<digits>.<digits>".
    """
    if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ:
        return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
    capabilities_str = repository_ctx.os.environ[_TF_CUDA_COMPUTE_CAPABILITIES]
    capabilities = capabilities_str.split(",")
    for capability in capabilities:
        # Workaround for Skylark's lack of support for regex. This check should
        # be equivalent to checking:
        #     if re.match("[0-9]+.[0-9]+", capability) == None:
        parts = capability.split(".")
        if len(parts) != 2 or not parts[0].isdigit() or not parts[1].isdigit():
            auto_configure_fail("Invalid compute capability: %s" % capability)
    return capabilities

def get_cpu_value(repository_ctx):
    """Returns the name of the host operating system.

    Args:
      repository_ctx: The repository context.

    Returns:
      A string containing the name of the host operating system: "Darwin",
      "Windows", or otherwise the output of `uname -s`.
    """
    os_name = repository_ctx.os.name.lower()
    if os_name.startswith("mac os"):
        return "Darwin"
    if os_name.find("windows") != -1:
        return "Windows"
    result = repository_ctx.execute(["uname", "-s"])
    return result.stdout.strip()

def _is_windows(repository_ctx):
    """Returns true if the host operating system is windows."""
    return repository_ctx.os.name.lower().find("windows") >= 0

def lib_name(base_name, cpu_value, version = None, static = False):
    """Constructs the platform-specific name of a library.

    Args:
      base_name: The name of the library, such as "cudart"
      cpu_value: The name of the host operating system.
      version: The version of the library. Ignored for static libraries and
        on Windows.
      static: True the library is static or False if it is a shared object.

    Returns:
      The platform-specific name of the library.
    """
    version = "" if not version else "." + version
    if cpu_value in ("Linux", "FreeBSD"):
        if static:
            return "lib%s.a" % base_name
        return "lib%s.so%s" % (base_name, version)
    elif cpu_value == "Windows":
        return "%s.lib" % base_name
    elif cpu_value == "Darwin":
        if static:
            return "lib%s.a" % base_name
        return "lib%s%s.dylib" % (base_name, version)
    else:
        auto_configure_fail("Invalid cpu_value: %s" % cpu_value)

def find_lib(repository_ctx, paths, check_soname = True):
    """Finds a library among a list of potential paths.

    Args:
      repository_ctx: The repository context.
      paths: List of paths to inspect.
      check_soname: When true (and objdump is available on a non-Windows
        host), a candidate is accepted only if one of the SONAME entries
        printed by `objdump -p` equals the file's basename.

    Returns:
      Returns the first path in paths that exist (and passes the SONAME check
      when it applies). Fails the configuration when no path qualifies.
    """
    objdump = repository_ctx.which("objdump")
    mismatches = []
    for path in [repository_ctx.path(candidate) for candidate in paths]:
        if not path.exists:
            continue
        if check_soname and objdump != None and not _is_windows(repository_ctx):
            output = repository_ctx.execute([objdump, "-p", str(path)]).stdout
            output = [line for line in output.splitlines() if "SONAME" in line]
            sonames = [line.strip().split(" ")[-1] for line in output]
            if not any([soname == path.basename for soname in sonames]):
                mismatches.append(str(path))
                continue
        return path
    if mismatches:
        auto_configure_fail(
            "None of the libraries match their SONAME: " + ", ".join(mismatches),
        )
    auto_configure_fail("No library found under: " + ", ".join(paths))

def _find_cuda_lib(
        lib,
        repository_ctx,
        cpu_value,
        basedir,
        version,
        static = False):
    """Finds the given CUDA or cuDNN library on the system.

    Args:
      lib: The name of the library, such as "cudart"
      repository_ctx: The repository context.
      cpu_value: The name of the host operating system.
      basedir: The install directory of CUDA or cuDNN.
      version: The version of the library.
      static: True if static library, False if shared object.

    Returns:
      Returns the path to the library.
    """
    file_name = lib_name(lib, cpu_value, version, static)
    return find_lib(
        repository_ctx,
        ["%s/%s" % (basedir, file_name)],
        # The SONAME is only checked for versioned shared objects.
        check_soname = version and not static,
    )

def _find_libs(repository_ctx, cuda_config):
    """Returns the CUDA and cuDNN libraries on the system.

    Args:
      repository_ctx: The repository context.
      cuda_config: The CUDA config as returned by _get_cuda_config

    Returns:
      Map of library names to the library paths returned by _find_cuda_lib.
    """
    cpu_value = cuda_config.cpu_value

    # The driver library "cuda" is looked up in the "stubs" subdirectory,
    # except on Windows.
    stub_dir = "" if _is_windows(repository_ctx) else "/stubs"
    return {
        "cuda": _find_cuda_lib(
            "cuda",
            repository_ctx,
            cpu_value,
            cuda_config.config["cuda_library_dir"] + stub_dir,
            None,
        ),
        "cudart": _find_cuda_lib(
            "cudart",
            repository_ctx,
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cuda_version,
        ),
        "cudart_static": _find_cuda_lib(
            "cudart_static",
            repository_ctx,
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cuda_version,
            static = True,
        ),
        "cublas": _find_cuda_lib(
            "cublas",
            repository_ctx,
            cpu_value,
            cuda_config.config["cublas_library_dir"],
            cuda_config.cuda_lib_version,
        ),
        "cusolver": _find_cuda_lib(
            "cusolver",
            repository_ctx,
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cuda_lib_version,
        ),
        "curand": _find_cuda_lib(
            "curand",
            repository_ctx,
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cuda_lib_version,
        ),
        "cufft": _find_cuda_lib(
            "cufft",
            repository_ctx,
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cuda_lib_version,
        ),
        "cudnn": _find_cuda_lib(
            "cudnn",
            repository_ctx,
            cpu_value,
            cuda_config.config["cudnn_library_dir"],
            cuda_config.cudnn_version,
        ),
        "cupti": _find_cuda_lib(
            "cupti",
            repository_ctx,
            cpu_value,
            cuda_config.config["cupti_library_dir"],
            cuda_config.cuda_version,
        ),
        "cusparse": _find_cuda_lib(
            "cusparse",
            repository_ctx,
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cuda_lib_version,
        ),
    }

def _cudart_static_linkopt(cpu_value):
    """Returns additional platform-specific linkopts for cudart."""
    return "" if cpu_value == "Darwin" else "\"-lrt\","

# TODO(csigg): Only call once instead of from
# here, tensorrt_configure.bzl,
# and nccl_configure.bzl.
def find_cuda_config(repository_ctx, cuda_libraries):
    """Returns CUDA config dictionary from running find_cuda_config.py

    Args:
      repository_ctx: The repository context.
      cuda_libraries: List of library names passed to the script as arguments.

    Returns:
      A dict built from the "key: value" lines the script prints on stdout.
    """
    exec_result = repository_ctx.execute([
        _get_python_bin(repository_ctx),
        repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py")),
    ] + cuda_libraries)
    if exec_result.return_code:
        auto_configure_fail("Failed to run find_cuda_config.py: %s" % exec_result.stderr)

    # Parse the dict from stdout. Split on the first ": " only, so a value
    # that itself contains ": " still yields a (key, value) pair.
    return dict([tuple(x.split(": ", 1)) for x in exec_result.stdout.splitlines()])

def _get_cuda_config(repository_ctx):
    """Detects and returns information about the CUDA installation on the system.

    Args:
      repository_ctx: The repository context.

    Returns:
      A struct containing the following fields:
        cuda_toolkit_path: The CUDA toolkit installation directory.
        cuda_version: The version of CUDA on the system, formatted as
          "<major>.<minor>" ("64_<major><minor>" on Windows).
        cudnn_version: The version of cuDNN on the system ("64_"-prefixed on
          Windows).
        cuda_lib_version: The version used for libraries like cuBLAS.
        compute_capabilities: A list of the system's CUDA compute capabilities.
        cpu_value: The name of the host operating system.
        config: The raw dict returned by find_cuda_config.
    """
    config = find_cuda_config(repository_ctx, ["cuda", "cudnn"])
    cpu_value = get_cpu_value(repository_ctx)
    toolkit_path = config["cuda_toolkit_path"]

    is_windows = _is_windows(repository_ctx)
    cuda_version = config["cuda_version"].split(".")
    cuda_major = cuda_version[0]
    cuda_minor = cuda_version[1]

    cuda_version = ("64_%s%s" if is_windows else "%s.%s") % (cuda_major, cuda_minor)
    cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"]

    # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
    # It changed from 'x.y' to just 'x' in CUDA 10.1.
    if (int(cuda_major), int(cuda_minor)) >= (10, 1):
        cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major
    else:
        cuda_lib_version = cuda_version

    return struct(
        cuda_toolkit_path = toolkit_path,
        cuda_version = cuda_version,
        cudnn_version = cudnn_version,
        cuda_lib_version = cuda_lib_version,
        compute_capabilities = compute_capabilities(repository_ctx),
        cpu_value = cpu_value,
        config = config,
    )

def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Instantiates //third_party/gpus/<tpl>.tpl with the given substitutions.

    When out is not given, the output path is tpl with ":" replaced by "/".
    """
    if not out:
        out = tpl.replace(":", "/")
    repository_ctx.template(
        out,
        Label("//third_party/gpus/%s.tpl" % tpl),
        substitutions,
    )

def _file(repository_ctx, label):
    """Instantiates //third_party/gpus/<label>.tpl without substitutions."""
    repository_ctx.template(
        label.replace(":", "/"),
        Label("//third_party/gpus/%s.tpl" % label),
        {},
    )

_DUMMY_CROSSTOOL_BZL_FILE = """
def error_gpu_disabled():
  fail("ERROR: Building with --config=cuda but TensorFlow is not configured " +
       "to build with GPU support. Please re-run ./configure and enter 'Y' " +
       "at the prompt to build with GPU support.")

  native.genrule(
      name = "error_gen_crosstool",
      outs = ["CROSSTOOL"],
      cmd = "echo 'Should not be run.' && exit 1",
  )

  native.filegroup(
      name = "crosstool",
      srcs = [":CROSSTOOL"],
      output_licenses = ["unencumbered"],
  )
"""

_DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")

error_gpu_disabled()
"""

def _create_dummy_repository(repository_ctx):
    """Creates a placeholder repository: stub templates and empty CUDA files."""
    cpu_value = get_cpu_value(repository_ctx)

    # Set up BUILD file for cuda/.
    _tpl(
        repository_ctx,
        "cuda:build_defs.bzl",
        {
            "%{cuda_is_configured}": "False",
            "%{cuda_extra_copts}": "[]",
        },
    )
    _tpl(
        repository_ctx,
        "cuda:BUILD",
        {
            "%{cuda_driver_lib}": lib_name("cuda", cpu_value),
            "%{cudart_static_lib}": lib_name(
                "cudart_static",
                cpu_value,
                static = True,
            ),
            "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
            "%{cudart_lib}": lib_name("cudart", cpu_value),
            "%{cublas_lib}": lib_name("cublas", cpu_value),
            "%{cusolver_lib}": lib_name("cusolver", cpu_value),
            "%{cudnn_lib}": lib_name("cudnn", cpu_value),
            "%{cufft_lib}": lib_name("cufft", cpu_value),
            "%{curand_lib}": lib_name("curand", cpu_value),
            "%{cupti_lib}": lib_name("cupti", cpu_value),
            "%{cusparse_lib}": lib_name("cusparse", cpu_value),
            "%{copy_rules}": """
filegroup(name="cuda-include")
filegroup(name="cublas-include")
filegroup(name="cudnn-include")
""",
        },
    )

    # Create dummy files for the CUDA toolkit since they are still required by
    # tensorflow/core/platform/default/build_config:cuda.
    repository_ctx.file("cuda/cuda/include/cuda.h")
    repository_ctx.file("cuda/cuda/include/cublas.h")
    repository_ctx.file("cuda/cuda/include/cudnn.h")
    repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h")
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cuda", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudart", cpu_value))
    repository_ctx.file(
        "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value),
    )
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cufft", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cupti", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusparse", cpu_value))

    # Set up cuda_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    _tpl(
        repository_ctx,
        "cuda:cuda_config.h",
        {
            "%{cuda_version}": "",
            "%{cuda_lib_version}": "",
            "%{cudnn_version}": "",
            "%{cuda_compute_capabilities}": ",".join([
                "CudaVersion(\"%s\")" % c
                for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES
            ]),
            "%{cuda_toolkit_path}": "",
        },
        "cuda/cuda/cuda_config.h",
    )

    # If cuda_configure is not configured to build with GPU support, and the user
    # attempts to build with --config=cuda, add a dummy build rule to intercept
    # this and fail with an actionable error message.
    repository_ctx.file(
        "crosstool/error_gpu_disabled.bzl",
        _DUMMY_CROSSTOOL_BZL_FILE,
    )
    repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)

def _execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        empty_stdout_fine = False):
    """Executes an arbitrary shell command.

    Fails the configuration when the command writes to stderr, or when stdout
    is empty and empty_stdout_fine is False.

    Args:
      repository_ctx: the repository_ctx object
      cmdline: list of strings, the command to execute
      error_msg: string, a summary of the error if the command fails
      error_details: string, details about the error or steps to fix it
      empty_stdout_fine: bool, if True, an empty stdout result is fine,
        otherwise it's an error

    Returns:
      the result of repository_ctx.execute(cmdline)
    """
    result = repository_ctx.execute(cmdline)
    if result.stderr or not (empty_stdout_fine or result.stdout):
        auto_configure_fail(
            "\n".join([
                error_msg.strip() if error_msg else "Repository command failed",
                result.stderr.strip(),
                error_details if error_details else "",
            ]),
        )
    return result

def _norm_path(path):
    """Returns a path with '/' and remove the trailing slash."""
    path = path.replace("\\", "/")

    # endswith (rather than path[-1]) keeps an empty path from failing.
    if path.endswith("/"):
        path = path[:-1]
    return path

def make_copy_files_rule(repository_ctx, name, srcs, outs):
    """Returns a rule to copy a set of files.

    Args:
      repository_ctx: The repository context (not used by this function).
      name: Name of the generated genrule.
      srcs: Source file paths, paired positionally with outs.
      outs: Output paths of the genrule.

    Returns:
      The text of a genrule that copies each src to its out with `cp -f`.
    """
    cmds = []

    # Copy files.
    for src, out in zip(srcs, outs):
        cmds.append('cp -f "%s" "$(location %s)"' % (src, out))
    outs = [('        "%s",' % out) for out in outs]
    return """genrule(
    name = "%s",
    outs = [
%s
    ],
    cmd = \"""%s \""",
)""" % (name, "\n".join(outs), " && \\\n".join(cmds))

def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir):
    """Returns a rule to recursively copy a directory.

    Args:
      repository_ctx: The repository context.
      name: Name of the generated genrule.
      src_dir: Directory whose files are listed (via _read_dir) and copied.
      out_dir: Output directory, substituted for src_dir in each output path.

    Returns:
      The text of a genrule that copies src_dir with `cp -rLf`.
    """
    src_dir = _norm_path(src_dir)
    out_dir = _norm_path(out_dir)
    outs = _read_dir(repository_ctx, src_dir)
    outs = [('        "%s",' % out.replace(src_dir, out_dir)) for out in outs]

    # '@D' already contains the relative path for a single file, see
    # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables
    out_dir = "$(@D)/%s" % out_dir if len(outs) > 1 else "$(@D)"
    return """genrule(
    name = "%s",
    outs = [
%s
    ],
    cmd = \"""cp -rLf "%s/." "%s/" \""",
)""" % (name, "\n".join(outs), src_dir, out_dir)

def _read_dir(repository_ctx, src_dir):
    """Returns a sorted list with all files in a directory.

    Finds all files inside a directory, traversing subfolders and following
    symlinks. Each list entry is the path of one file as printed by the
    platform's listing command (`dir /b /s /a-d` on Windows, with backslashes
    converted to forward slashes; `find -follow -type f` elsewhere).
    """
    if _is_windows(repository_ctx):
        src_dir = src_dir.replace("/", "\\")
        find_result = _execute(
            repository_ctx,
            ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
            empty_stdout_fine = True,
        )

        # src_files will be used in genrule.outs where the paths must
        # use forward slashes.
        result = find_result.stdout.replace("\\", "/")
    else:
        find_result = _execute(
            repository_ctx,
            ["find", src_dir, "-follow", "-type", "f"],
            empty_stdout_fine = True,
        )
        result = find_result.stdout
    return sorted(result.splitlines())

def _flag_enabled(repository_ctx, flag_name):
    """Returns True only when the environment variable flag_name is "1"."""
    if flag_name in repository_ctx.os.environ:
        value = repository_ctx.os.environ[flag_name].strip()
        return value == "1"
    return False

def _use_cuda_clang(repository_ctx):
    """Returns True when TF_CUDA_CLANG is enabled in the environment."""
    return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")

def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
    """Returns str() of the list of extra copts for the compute capabilities.

    With cuda clang one "--cuda-gpu-arch=sm_XY" flag is produced per
    capability; otherwise the list is empty.
    """
    if _use_cuda_clang(repository_ctx):
        capability_flags = [
            "--cuda-gpu-arch=sm_" + cap.replace(".", "")
            for cap in compute_capabilities
        ]
    else:
        # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc
        # TODO(csigg): Make this consistent with cuda clang and pass to crosstool.
        capability_flags = []
    return str(capability_flags)

def _create_local_cuda_repository(repository_ctx):
    """Creates the repository containing files set up to build with CUDA."""
    cuda_config = _get_cuda_config(repository_ctx)

    cuda_include_path = cuda_config.config["cuda_include_dir"]
    cublas_include_path = cuda_config.config["cublas_include_dir"]
    cudnn_header_dir = cuda_config.config["cudnn_include_dir"]
    cupti_header_dir = cuda_config.config["cupti_include_dir"]
    nvvm_libdevice_dir = cuda_config.config["nvvm_library_dir"]

    # Create genrule to copy files from the installed CUDA toolkit into execroot.
997 copy_rules = [ 998 make_copy_dir_rule( 999 repository_ctx, 1000 name = "cuda-include", 1001 src_dir = cuda_include_path, 1002 out_dir = "cuda/include", 1003 ), 1004 make_copy_dir_rule( 1005 repository_ctx, 1006 name = "cuda-nvvm", 1007 src_dir = nvvm_libdevice_dir, 1008 out_dir = "cuda/nvvm/libdevice", 1009 ), 1010 make_copy_dir_rule( 1011 repository_ctx, 1012 name = "cuda-extras", 1013 src_dir = cupti_header_dir, 1014 out_dir = "cuda/extras/CUPTI/include", 1015 ), 1016 ] 1017 1018 copy_rules.append(make_copy_files_rule( 1019 repository_ctx, 1020 name = "cublas-include", 1021 srcs = [ 1022 cublas_include_path + "/cublas.h", 1023 cublas_include_path + "/cublas_v2.h", 1024 cublas_include_path + "/cublas_api.h", 1025 ], 1026 outs = [ 1027 "cublas/include/cublas.h", 1028 "cublas/include/cublas_v2.h", 1029 "cublas/include/cublas_api.h", 1030 ], 1031 )) 1032 1033 cuda_libs = _find_libs(repository_ctx, cuda_config) 1034 cuda_lib_srcs = [] 1035 cuda_lib_outs = [] 1036 for path in cuda_libs.values(): 1037 cuda_lib_srcs.append(str(path)) 1038 cuda_lib_outs.append("cuda/lib/" + path.basename) 1039 copy_rules.append(make_copy_files_rule( 1040 repository_ctx, 1041 name = "cuda-lib", 1042 srcs = cuda_lib_srcs, 1043 outs = cuda_lib_outs, 1044 )) 1045 1046 copy_rules.append(make_copy_dir_rule( 1047 repository_ctx, 1048 name = "cuda-bin", 1049 src_dir = cuda_config.cuda_toolkit_path + "/bin", 1050 out_dir = "cuda/bin", 1051 )) 1052 1053 copy_rules.append(make_copy_files_rule( 1054 repository_ctx, 1055 name = "cudnn-include", 1056 srcs = [cudnn_header_dir + "/cudnn.h"], 1057 outs = ["cudnn/include/cudnn.h"], 1058 )) 1059 1060 # Set up BUILD file for cuda/ 1061 _tpl( 1062 repository_ctx, 1063 "cuda:build_defs.bzl", 1064 { 1065 "%{cuda_is_configured}": "True", 1066 "%{cuda_extra_copts}": _compute_cuda_extra_copts( 1067 repository_ctx, 1068 cuda_config.compute_capabilities, 1069 ), 1070 }, 1071 ) 1072 _tpl( 1073 repository_ctx, 1074 "cuda:BUILD.windows" if 
_is_windows(repository_ctx) else "cuda:BUILD", 1075 { 1076 "%{cuda_driver_lib}": cuda_libs["cuda"].basename, 1077 "%{cudart_static_lib}": cuda_libs["cudart_static"].basename, 1078 "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value), 1079 "%{cudart_lib}": cuda_libs["cudart"].basename, 1080 "%{cublas_lib}": cuda_libs["cublas"].basename, 1081 "%{cusolver_lib}": cuda_libs["cusolver"].basename, 1082 "%{cudnn_lib}": cuda_libs["cudnn"].basename, 1083 "%{cufft_lib}": cuda_libs["cufft"].basename, 1084 "%{curand_lib}": cuda_libs["curand"].basename, 1085 "%{cupti_lib}": cuda_libs["cupti"].basename, 1086 "%{cusparse_lib}": cuda_libs["cusparse"].basename, 1087 "%{copy_rules}": "\n".join(copy_rules), 1088 }, 1089 "cuda/BUILD", 1090 ) 1091 1092 is_cuda_clang = _use_cuda_clang(repository_ctx) 1093 1094 should_download_clang = is_cuda_clang and _flag_enabled( 1095 repository_ctx, 1096 _TF_DOWNLOAD_CLANG, 1097 ) 1098 if should_download_clang: 1099 download_clang(repository_ctx, "crosstool/extra_tools") 1100 1101 # Set up crosstool/ 1102 cc = find_cc(repository_ctx) 1103 cc_fullpath = cc if not should_download_clang else "crosstool/" + cc 1104 1105 host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc_fullpath) 1106 cuda_defines = {} 1107 1108 host_compiler_prefix = "/usr/bin" 1109 if _GCC_HOST_COMPILER_PREFIX in repository_ctx.os.environ: 1110 host_compiler_prefix = repository_ctx.os.environ[_GCC_HOST_COMPILER_PREFIX].strip() 1111 cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix 1112 1113 # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see 1114 # https://github.com/bazelbuild/bazel/issues/760). 1115 # However, this stops our custom clang toolchain from picking the provided 1116 # LLD linker, so we're only adding '-B/usr/bin' when using non-downloaded 1117 # toolchain. 
1118 # TODO: when bazel stops adding '-B/usr/bin' by default, remove this 1119 # flag from the CROSSTOOL completely (see 1120 # https://github.com/bazelbuild/bazel/issues/5634) 1121 if should_download_clang: 1122 cuda_defines["%{linker_bin_path}"] = "" 1123 else: 1124 cuda_defines["%{linker_bin_path}"] = host_compiler_prefix 1125 1126 cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" 1127 cuda_defines["%{unfiltered_compile_flags}"] = "" 1128 if is_cuda_clang: 1129 cuda_defines["%{host_compiler_path}"] = str(cc) 1130 cuda_defines["%{host_compiler_warnings}"] = """ 1131 # Some parts of the codebase set -Werror and hit this warning, so 1132 # switch it off for now. 1133 "-Wno-invalid-partial-specialization" 1134 """ 1135 cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(host_compiler_includes) 1136 cuda_defines["%{linker_files}"] = ":empty" 1137 cuda_defines["%{win_linker_files}"] = ":empty" 1138 repository_ctx.file( 1139 "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", 1140 "", 1141 ) 1142 repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "") 1143 else: 1144 cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" 1145 cuda_defines["%{host_compiler_warnings}"] = "" 1146 1147 # nvcc has the system include paths built in and will automatically 1148 # search them; we cannot work around that, so we add the relevant cuda 1149 # system paths to the allowed compiler specific include paths. 
1150 cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings( 1151 host_compiler_includes + _cuda_include_path( 1152 repository_ctx, 1153 cuda_config, 1154 ) + [cupti_header_dir, cudnn_header_dir], 1155 ) 1156 1157 # For gcc, do not canonicalize system header paths; some versions of gcc 1158 # pick the shortest possible path for system includes when creating the 1159 # .d file - given that includes that are prefixed with "../" multiple 1160 # time quickly grow longer than the root of the tree, this can lead to 1161 # bazel's header check failing. 1162 cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\"" 1163 1164 nvcc_path = str( 1165 repository_ctx.path("%s/nvcc%s" % ( 1166 cuda_config.config["cuda_binary_dir"], 1167 ".exe" if _is_windows(repository_ctx) else "", 1168 )), 1169 ) 1170 cuda_defines["%{linker_files}"] = ":crosstool_wrapper_driver_is_not_gcc" 1171 cuda_defines["%{win_linker_files}"] = ":windows_msvc_wrapper_files" 1172 1173 wrapper_defines = { 1174 "%{cpu_compiler}": str(cc), 1175 "%{cuda_version}": cuda_config.cuda_version, 1176 "%{nvcc_path}": nvcc_path, 1177 "%{gcc_host_compiler_path}": str(cc), 1178 "%{cuda_compute_capabilities}": ", ".join( 1179 ["\"%s\"" % c for c in cuda_config.compute_capabilities], 1180 ), 1181 "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx), 1182 } 1183 _tpl( 1184 repository_ctx, 1185 "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc", 1186 wrapper_defines, 1187 ) 1188 _tpl( 1189 repository_ctx, 1190 "crosstool:windows/msvc_wrapper_for_nvcc.py", 1191 wrapper_defines, 1192 ) 1193 1194 cuda_defines.update(_get_win_cuda_defines(repository_ctx)) 1195 1196 verify_build_defines(cuda_defines) 1197 1198 # Only expand template variables in the BUILD file 1199 _tpl(repository_ctx, "crosstool:BUILD", cuda_defines) 1200 1201 # No templating of cc_toolchain_config - use attributes and templatize the 1202 # BUILD file. 
1203 _file(repository_ctx, "crosstool:cc_toolchain_config.bzl") 1204 1205 # Set up cuda_config.h, which is used by 1206 # tensorflow/stream_executor/dso_loader.cc. 1207 _tpl( 1208 repository_ctx, 1209 "cuda:cuda_config.h", 1210 { 1211 "%{cuda_version}": cuda_config.cuda_version, 1212 "%{cuda_lib_version}": cuda_config.cuda_lib_version, 1213 "%{cudnn_version}": cuda_config.cudnn_version, 1214 "%{cuda_compute_capabilities}": ", ".join([ 1215 "CudaVersion(\"%s\")" % c 1216 for c in cuda_config.compute_capabilities 1217 ]), 1218 "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path, 1219 }, 1220 "cuda/cuda/cuda_config.h", 1221 ) 1222 1223 def _create_remote_cuda_repository(repository_ctx, remote_config_repo): 1224 """Creates pointers to a remotely configured repo set up to build with CUDA.""" 1225 _tpl( 1226 repository_ctx, 1227 "cuda:build_defs.bzl", 1228 { 1229 "%{cuda_is_configured}": "True", 1230 "%{cuda_extra_copts}": _compute_cuda_extra_copts( 1231 repository_ctx, 1232 compute_capabilities(repository_ctx), 1233 ), 1234 }, 1235 ) 1236 repository_ctx.template( 1237 "cuda/BUILD", 1238 Label(remote_config_repo + "/cuda:BUILD"), 1239 {}, 1240 ) 1241 repository_ctx.template( 1242 "cuda/build_defs.bzl", 1243 Label(remote_config_repo + "/cuda:build_defs.bzl"), 1244 {}, 1245 ) 1246 repository_ctx.template( 1247 "cuda/cuda/cuda_config.h", 1248 Label(remote_config_repo + "/cuda:cuda/cuda_config.h"), 1249 {}, 1250 ) 1251 1252 def _cuda_autoconf_impl(repository_ctx): 1253 """Implementation of the cuda_autoconf repository rule.""" 1254 if not enable_cuda(repository_ctx): 1255 _create_dummy_repository(repository_ctx) 1256 elif _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ: 1257 if (_TF_CUDA_VERSION not in repository_ctx.os.environ or 1258 _TF_CUDNN_VERSION not in repository_ctx.os.environ): 1259 auto_configure_fail("%s and %s must also be set if %s is specified" % 1260 (_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO)) 1261 _create_remote_cuda_repository( 1262 
repository_ctx, 1263 repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO], 1264 ) 1265 else: 1266 _create_local_cuda_repository(repository_ctx) 1267 1268 cuda_configure = repository_rule( 1269 implementation = _cuda_autoconf_impl, 1270 environ = [ 1271 _GCC_HOST_COMPILER_PATH, 1272 _GCC_HOST_COMPILER_PREFIX, 1273 _CLANG_CUDA_COMPILER_PATH, 1274 "TF_NEED_CUDA", 1275 "TF_CUDA_CLANG", 1276 _TF_DOWNLOAD_CLANG, 1277 _CUDA_TOOLKIT_PATH, 1278 _CUDNN_INSTALL_PATH, 1279 _TF_CUDA_VERSION, 1280 _TF_CUDNN_VERSION, 1281 _TF_CUDA_COMPUTE_CAPABILITIES, 1282 _TF_CUDA_CONFIG_REPO, 1283 "NVVMIR_LIBRARY_DIR", 1284 _PYTHON_BIN_PATH, 1285 "TMP", 1286 "TMPDIR", 1287 "TF_CUDA_PATHS", 1288 ], 1289 ) 1290 1291 """Detects and configures the local CUDA toolchain. 1292 1293 Add the following to your WORKSPACE FILE: 1294 1295 ```python 1296 cuda_configure(name = "local_config_cuda") 1297 ``` 1298 1299 Args: 1300 name: A unique name for this workspace rule. 1301 """