github.com/rohankumardubey/aresdb@v0.0.2-0.20190517170215-e54e3ca06b9c/query/filter.cu (about)

     1  //  Copyright (c) 2017-2018 Uber Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  #include <algorithm>
    16  #include <cstdio>
    17  #include <cstring>
    18  #include <exception>
    19  #include <vector>
    20  #include <initializer_list>
    21  #include "query/transform.hpp"
    22  #include "query/binder.hpp"
    23  
    24  namespace ares {
    25  
    26  // FilterContext is doing the actual filter after binding one or two
    27  // input iterators.
    28  template<typename FunctorType>
    29  class FilterContext {
    30   public:
    31    FilterContext(
    32        uint8_t *predicateVector, int indexVectorLength,
    33        RecordID **foreignTableRecordIDVectors,
    34        int numForeignTables, FunctorType functorType,
    35        void *cudaStream)
    36        : predicateVector(predicateVector),
    37          indexVectorLength(indexVectorLength),
    38          foreignTableRecordIDVectors(foreignTableRecordIDVectors),
    39          numForeignTables(numForeignTables),
    40          functorType(functorType),
    41          cudaStream(reinterpret_cast<cudaStream_t>(cudaStream)) {}
    42  
    43    cudaStream_t getStream() const {
    44      return cudaStream;
    45    }
    46  
    47    template<typename InputIterator>
    48    int run(uint32_t *indexVector, InputIterator inputIterator);
    49  
    50    template<typename LHSIterator, typename RHSIterator>
    51    int run(uint32_t *indexVector, LHSIterator lhsIter, RHSIterator rhsIter);
    52  
    53   private:
    54    uint8_t *predicateVector;
    55    int indexVectorLength;
    56    RecordID **foreignTableRecordIDVectors;
    57    int numForeignTables;
    58    FunctorType functorType;
    59    cudaStream_t cudaStream;
    60  
    61    template<typename LHSIterator, typename RHSIterator,
    62        typename IndexZipIterator>
    63    int executeRemoveIf(LHSIterator lhsIter,
    64                        RHSIterator rhsIter,
    65                        IndexZipIterator indexZipIterator);
    66  
    67    template<typename InputIterator, typename IndexZipIterator>
    68    int executeRemoveIf(InputIterator inputIter,
    69                        IndexZipIterator indexZipIterator);
    70  };
    71  
    72  }  // namespace ares
    73  
    74  CGoCallResHandle UnaryFilter(InputVector input,
    75                               uint32_t *indexVector,
    76                               uint8_t *predicateVector,
    77                               int indexVectorLength,
    78                               RecordID **foreignTableRecordIDVectors,
    79                               int numForeignTables,
    80                               uint32_t *baseCounts,
    81                               uint32_t startCount,
    82                               UnaryFunctorType functorType,
    83                               void *cudaStream,
    84                               int device) {
    85    CGoCallResHandle resHandle = {nullptr, nullptr};
    86    try {
    87  #ifdef RUN_ON_DEVICE
    88      cudaSetDevice(device);
    89  #endif
    90      ares::FilterContext<UnaryFunctorType> ctx(predicateVector,
    91                                                indexVectorLength,
    92                                                foreignTableRecordIDVectors,
    93                                                numForeignTables,
    94                                                functorType,
    95                                                cudaStream);
    96      std::vector<InputVector> inputVectors = {input};
    97      ares::InputVectorBinder<ares::FilterContext<UnaryFunctorType>, 1>
    98          binder(ctx, inputVectors, indexVector, baseCounts, startCount);
    99      resHandle.res =
   100          reinterpret_cast<void *>(binder.bind());
   101      CheckCUDAError("UnaryFilter");
   102    }
   103    catch (std::exception &e) {
   104      std::cerr << "Exception happend when doing UnaryFilter:" << e.what()
   105                << std::endl;
   106      resHandle.pStrErr = strdup(e.what());
   107    }
   108    return resHandle;
   109  }
   110  
   111  CGoCallResHandle BinaryFilter(InputVector lhs,
   112                                InputVector rhs,
   113                                uint32_t *indexVector,
   114                                uint8_t *predicateVector,
   115                                int indexVectorLength,
   116                                RecordID **foreignTableRecordIDVectors,
   117                                int numForeignTables,
   118                                uint32_t *baseCounts,
   119                                uint32_t startCount,
   120                                BinaryFunctorType functorType,
   121                                void *cudaStream,
   122                                int device) {
   123    CGoCallResHandle resHandle = {nullptr, nullptr};
   124    try {
   125  #ifdef RUN_ON_DEVICE
   126      cudaSetDevice(device);
   127  #endif
   128      ares::FilterContext<BinaryFunctorType> ctx(predicateVector,
   129                                                 indexVectorLength,
   130                                                 foreignTableRecordIDVectors,
   131                                                 numForeignTables,
   132                                                 functorType,
   133                                                 cudaStream);
   134      std::vector<InputVector> inputVectors = {lhs, rhs};
   135      ares::InputVectorBinder<ares::FilterContext<BinaryFunctorType>, 2> binder(
   136          ctx, inputVectors, indexVector, baseCounts, startCount);
   137  
   138      resHandle.res =
   139          reinterpret_cast<void *>(binder.bind());
   140      CheckCUDAError("BinaryFilter");
   141    }
   142    catch (std::exception &e) {
   143      std::cerr << "Exception happend when doing BinaryFilter:" << e.what()
   144                << std::endl;
   145      resHandle.pStrErr = strdup(e.what());
   146    }
   147    return resHandle;
   148  }
   149  
   150  namespace ares {
   151  
   152  // Filter template function for unary transform filter.
   153  template<typename FunctorType>
   154  template<typename InputIterator, typename IndexZipIterator>
   155  int FilterContext<FunctorType>::executeRemoveIf(
   156      InputIterator inputIter,
   157      IndexZipIterator indexZipIterator) {
   158    typedef typename InputIterator::value_type::head_type InputValueType;
   159    UnaryPredicateFunctor<bool, InputValueType> f(functorType);
   160    RemoveFilter<typename IndexZipIterator::value_type, uint8_t> removeFilter(
   161        predicateVector);
   162    // first compute the predicate values.
   163    thrust::transform(GET_EXECUTION_POLICY(cudaStream), inputIter,
   164                      inputIter + indexVectorLength, predicateVector, f);
   165    // then we use the predicate values to remove indexes in place.
   166    return thrust::remove_if(GET_EXECUTION_POLICY(cudaStream), indexZipIterator,
   167                             indexZipIterator + indexVectorLength, removeFilter) -
   168           indexZipIterator;
   169  }
   170  
   171  // run unary filter.
   172  template<typename FunctorType>
   173  template<typename InputIterator>
   174  int FilterContext<FunctorType>::run(uint32_t *indexVector,
   175                                      InputIterator inputIterator) {
   176    switch (numForeignTables) {
   177      #define EXECUTE_UNARY_REMOVE_IF(NumTotalForeignTables) \
   178      case NumTotalForeignTables: { \
   179        IndexZipIteratorMaker<NumTotalForeignTables> maker; \
   180        return executeRemoveIf(inputIterator, \
   181                             maker.make(indexVector, \
   182                                        foreignTableRecordIDVectors)); \
   183      }
   184  
   185      EXECUTE_UNARY_REMOVE_IF(0)
   186      EXECUTE_UNARY_REMOVE_IF(1)
   187      EXECUTE_UNARY_REMOVE_IF(2)
   188      EXECUTE_UNARY_REMOVE_IF(3)
   189      EXECUTE_UNARY_REMOVE_IF(4)
   190      EXECUTE_UNARY_REMOVE_IF(5)
   191      EXECUTE_UNARY_REMOVE_IF(6)
   192      EXECUTE_UNARY_REMOVE_IF(7)
   193      EXECUTE_UNARY_REMOVE_IF(8)
   194      default:throw std::invalid_argument("only support up to 8 foreign tables");
   195    }
   196  }
   197  
   198  // run binary filter.
   199  template<typename FunctorType>
   200  template<typename LHSIterator, typename RHSIterator, typename IndexZipIterator>
   201  int FilterContext<FunctorType>::executeRemoveIf(
   202      LHSIterator lhsIter,
   203      RHSIterator rhsIter,
   204      IndexZipIterator indexZipIterator) {
   205    typedef typename common_type<
   206        typename LHSIterator::value_type::head_type,
   207        typename RHSIterator::value_type::head_type>::type InputValueType;
   208    BinaryPredicateFunctor<bool, InputValueType> f(functorType);
   209    RemoveFilter<typename IndexZipIterator::value_type, uint8_t> removeFilter(
   210        predicateVector);
   211  
   212    // first compute the predicate values.
   213    thrust::transform(GET_EXECUTION_POLICY(cudaStream), lhsIter,
   214        lhsIter + indexVectorLength, rhsIter, predicateVector, f);
   215    // then we use the predicate values to remove indexes in place.
   216    return thrust::remove_if(GET_EXECUTION_POLICY(cudaStream), indexZipIterator,
   217                             indexZipIterator + indexVectorLength, removeFilter) -
   218           indexZipIterator;
   219  }
   220  
   221  // template partial specialization with output iterator as uint8_t* for binary
   222  // transform.
   223  template<typename FunctorType>
   224  template<typename LHSIterator, typename RHSIterator>
   225  int FilterContext<FunctorType>::run(uint32_t *indexVector,
   226                                      LHSIterator lhsIter,
   227                                      RHSIterator rhsIter) {
   228    switch (numForeignTables) {
   229      #define EXECUTE_BINARY_REMOVE_IF(NumTotalForeignTables) \
   230      case NumTotalForeignTables: { \
   231        IndexZipIteratorMaker<NumTotalForeignTables> maker; \
   232        return executeRemoveIf(lhsIter, rhsIter, maker.make(indexVector, \
   233                                 foreignTableRecordIDVectors)); \
   234      }
   235  
   236      EXECUTE_BINARY_REMOVE_IF(0)
   237      EXECUTE_BINARY_REMOVE_IF(1)
   238      EXECUTE_BINARY_REMOVE_IF(2)
   239      EXECUTE_BINARY_REMOVE_IF(3)
   240      EXECUTE_BINARY_REMOVE_IF(4)
   241      EXECUTE_BINARY_REMOVE_IF(5)
   242      EXECUTE_BINARY_REMOVE_IF(6)
   243      EXECUTE_BINARY_REMOVE_IF(7)
   244      EXECUTE_BINARY_REMOVE_IF(8)
   245      default:throw std::invalid_argument("only support up to 8 foreign tables");
   246    }
   247  }
   248  
   249  }  // namespace ares