github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp (about)

     1  //===- cuda-runtime-wrappers.cpp - MLIR CUDA runner wrapper library -------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // Implements C wrappers around the CUDA library for easy linking in ORC jit.
    19  // Also adds some debugging helpers that are helpful when writing MLIR code to
    20  // run on GPUs.
    21  //
    22  //===----------------------------------------------------------------------===//
    23  
    24  #include <assert.h>
    25  #include <memory.h>
    26  
    27  #include "llvm/Support/raw_ostream.h"
    28  
    29  #include "cuda.h"
    30  
    31  namespace {
    32  int32_t reportErrorIfAny(CUresult result, const char *where) {
    33    if (result != CUDA_SUCCESS) {
    34      llvm::errs() << "CUDA failed with " << result << " in " << where << "\n";
    35    }
    36    return result;
    37  }
    38  } // anonymous namespace
    39  
    40  extern "C" int32_t mcuModuleLoad(void **module, void *data) {
    41    int32_t err = reportErrorIfAny(
    42        cuModuleLoadData(reinterpret_cast<CUmodule *>(module), data),
    43        "ModuleLoad");
    44    return err;
    45  }
    46  
    47  extern "C" int32_t mcuModuleGetFunction(void **function, void *module,
    48                                          const char *name) {
    49    return reportErrorIfAny(
    50        cuModuleGetFunction(reinterpret_cast<CUfunction *>(function),
    51                            reinterpret_cast<CUmodule>(module), name),
    52        "GetFunction");
    53  }
    54  
    55  // The wrapper uses intptr_t instead of CUDA's unsigned int to match
    56  // the type of MLIR's index type. This avoids the need for casts in the
    57  // generated MLIR code.
    58  extern "C" int32_t mcuLaunchKernel(void *function, intptr_t gridX,
    59                                     intptr_t gridY, intptr_t gridZ,
    60                                     intptr_t blockX, intptr_t blockY,
    61                                     intptr_t blockZ, int32_t smem, void *stream,
    62                                     void **params, void **extra) {
    63    return reportErrorIfAny(
    64        cuLaunchKernel(reinterpret_cast<CUfunction>(function), gridX, gridY,
    65                       gridZ, blockX, blockY, blockZ, smem,
    66                       reinterpret_cast<CUstream>(stream), params, extra),
    67        "LaunchKernel");
    68  }
    69  
    70  extern "C" void *mcuGetStreamHelper() {
    71    CUstream stream;
    72    reportErrorIfAny(cuStreamCreate(&stream, CU_STREAM_DEFAULT), "StreamCreate");
    73    return stream;
    74  }
    75  
    76  extern "C" int32_t mcuStreamSynchronize(void *stream) {
    77    return reportErrorIfAny(
    78        cuStreamSynchronize(reinterpret_cast<CUstream>(stream)), "StreamSync");
    79  }
    80  
/// Helper functions for writing MLIR example code.
    82  
    83  // A struct that corresponds to how MLIR represents unknown-length 1d memrefs.
    84  struct memref_t {
    85    float *values;
    86    intptr_t length;
    87  };
    88  
    89  // Allows to register a pointer with the CUDA runtime. Helpful until
    90  // we have transfer functions implemented.
    91  extern "C" void mcuMemHostRegister(const memref_t arg, int32_t flags) {
    92    reportErrorIfAny(
    93        cuMemHostRegister(arg.values, arg.length * sizeof(float), flags),
    94        "MemHostRegister");
    95  }
    96  
    97  /// Prints the given float array to stderr.
    98  extern "C" void mcuPrintFloat(const memref_t arg) {
    99    if (arg.length == 0) {
   100      llvm::outs() << "[]\n";
   101      return;
   102    }
   103    llvm::outs() << "[" << arg.values[0];
   104    for (int pos = 1; pos < arg.length; pos++) {
   105      llvm::outs() << ", " << arg.values[pos];
   106    }
   107    llvm::outs() << "]\n";
   108  }