github.com/rohankumardubey/aresdb@v0.0.2-0.20190517170215-e54e3ca06b9c/query/transform.cu (about)

     1  //  Copyright (c) 2017-2018 Uber Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
     #include <algorithm>
     #include <cstring>
     #include <exception>
     #include <iostream>
     #include <stdexcept>
     #include <string>
     #include <vector>
     #include "query/transform.hpp"
    20  
    21  CGoCallResHandle UnaryTransform(InputVector input,
    22                                  OutputVector output,
    23                                  uint32_t *indexVector,
    24                                  int indexVectorLength,
    25                                  uint32_t *baseCounts,
    26                                  uint32_t startCount,
    27                                  UnaryFunctorType functorType,
    28                                  void *cudaStream,
    29                                  int device) {
    30    CGoCallResHandle resHandle = {nullptr, nullptr};
    31    try {
    32  #ifdef RUN_ON_DEVICE
    33      cudaSetDevice(device);
    34  #endif
    35      std::vector<InputVector> inputVectors = {input};
    36      ares::OutputVectorBinder<1, UnaryFunctorType> outputVectorBinder(output,
    37                                                             inputVectors,
    38                                                             indexVector,
    39                                                             indexVectorLength,
    40                                                             baseCounts,
    41                                                             startCount,
    42                                                             functorType,
    43                                                             cudaStream);
    44      resHandle.res = reinterpret_cast<void *>(outputVectorBinder.bind());
    45      CheckCUDAError("UnaryTransform");
    46    } catch (std::exception &e) {
    47      std::cerr << "Exception happend when doing UnaryTransform:" << e.what()
    48                << std::endl;
    49      resHandle.pStrErr = strdup(e.what());
    50    }
    51    return resHandle;
    52  }
    53  
    54  CGoCallResHandle BinaryTransform(InputVector lhs,
    55                                   InputVector rhs,
    56                                   OutputVector output,
    57                                   uint32_t *indexVector,
    58                                   int indexVectorLength,
    59                                   uint32_t *baseCounts,
    60                                   uint32_t startCount,
    61                                   BinaryFunctorType functorType,
    62                                   void *cudaStream,
    63                                   int device) {
    64    CGoCallResHandle resHandle = {nullptr, nullptr};
    65    try {
    66  #ifdef RUN_ON_DEVICE
    67      cudaSetDevice(device);
    68  #endif
    69      std::vector<InputVector> inputVectors = {lhs, rhs};
    70      ares::OutputVectorBinder<2, BinaryFunctorType> outputVectorBinder(output,
    71                                                              inputVectors,
    72                                                              indexVector,
    73                                                              indexVectorLength,
    74                                                              baseCounts,
    75                                                              startCount,
    76                                                              functorType,
    77                                                              cudaStream);
    78      resHandle.res = reinterpret_cast<void *>(outputVectorBinder.bind());
    79      CheckCUDAError("BinaryTransform");
    80    } catch (std::exception &e) {
    81      std::cerr << "Exception happend when doing BinaryTransform:" << e.what()
    82                << std::endl;
    83      resHandle.pStrErr = strdup(e.what());
    84    }
    85    return resHandle;
    86  }