github.com/rohankumardubey/aresdb@v0.0.2-0.20190517170215-e54e3ca06b9c/query/thrust_rmm_allocator.hpp (about)

     1  /*
     2   * Copyright (c) 2018, NVIDIA CORPORATION.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  /**
    18   Allocator class compatible with thrust arrays that uses RMM device memory manager.
    19   Author: Mark Harris
    20   */
    21  
    22  #ifndef QUERY_THRUST_RMM_ALLOCATOR_HPP_
    23  #define QUERY_THRUST_RMM_ALLOCATOR_HPP_
    24  
#include <rmm/rmm.h>
#include <thrust/device_vector.h>
#include <thrust/device_malloc_allocator.h>
#include <thrust/system_error.h>
#include <thrust/system/cuda/error.h>
#include <thrust/execution_policy.h>
#include <functional>
#include <memory>
#include <new>
#include <utility>
    32  
    33  template<class T>
    34  class rmm_allocator : public thrust::device_malloc_allocator<T> {
    35   public:
    36    using value_type = T;
    37  
    38    explicit rmm_allocator(cudaStream_t stream = 0) : stream(stream) {}
    39    ~rmm_allocator() {
    40    }
    41  
    42    typedef thrust::device_ptr<value_type> pointer;
    43    inline pointer allocate(size_t n) {
    44      value_type *result = nullptr;
    45  
    46      rmmError_t error = RMM_ALLOC((void **) &result, n * sizeof(value_type),
    47                                   stream);
    48  
    49      if (error != RMM_SUCCESS) {
    50        throw thrust::system_error(error, thrust::cuda_category(),
    51                                   "rmm_allocator::allocate(): RMM_ALLOC");
    52      }
    53  
    54      return thrust::device_pointer_cast(result);
    55    }
    56  
    57    inline void deallocate(pointer ptr, size_t) {
    58      rmmError_t error = RMM_FREE(thrust::raw_pointer_cast(ptr), stream);
    59  
    60      if (error != RMM_SUCCESS) {
    61        throw thrust::system_error(error, thrust::cuda_category(),
    62                                   "rmm_allocator::deallocate(): RMM_FREE");
    63      }
    64    }
    65  
    66   private:
    67    cudaStream_t stream;
    68  };
    69  
    70  namespace rmm {
    71  /**
    72   * @brief Alias for a thrust::device_vector that uses RMM for memory allocation.
    73   *
    74   */
    75  template<typename T>
    76  using device_vector = thrust::device_vector<T, rmm_allocator<T>>;
    77  
    78  using par_t = decltype(thrust::cuda::par(*(new rmm_allocator<char>(0))));
    79  using deleter_t = std::function<void(par_t *)>;
    80  using exec_policy_t = std::unique_ptr<par_t, deleter_t>;
    81  
    82  /* --------------------------------------------------------------------------*/
    83  /**
    84   * @brief Returns a unique_ptr to a Thrust CUDA execution policy that uses RMM
    85   * for temporary memory allocation.
    86   *
    87   * @Param stream The stream that the allocator will use
    88   *
    89   * @Returns A Thrust execution policy that will use RMM for temporary memory
    90   * allocation.
    91   */
    92  /* --------------------------------------------------------------------------*/
    93  inline exec_policy_t exec_policy(cudaStream_t stream = 0) {
    94    rmm_allocator<char> *alloc{nullptr};
    95    alloc = new rmm_allocator<char>(stream);
    96    auto deleter = [alloc](par_t *pointer) {
    97      delete alloc;
    98      delete pointer;
    99    };
   100  
   101    exec_policy_t policy{new par_t(*alloc), deleter};
   102    return policy;
   103  }
   104  
   105  }  // namespace rmm
   106  
   107  #endif  // QUERY_THRUST_RMM_ALLOCATOR_HPP_