// github.com/rohankumardubey/aresdb@v0.0.2-0.20190517170215-e54e3ca06b9c/query/filter.cu
// Copyright (c) 2017-2018 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <exception>
#include <initializer_list>
#include <iostream>
#include <stdexcept>
#include <vector>
#include "query/transform.hpp"
#include "query/binder.hpp"

namespace ares {

// FilterContext does the actual filtering after binding one or two
// input iterators.
template<typename FunctorType>
class FilterContext {
 public:
  // predicateVector: scratch buffer that run() overwrites with one predicate
  //   value per evaluated row; must hold at least indexVectorLength entries.
  // indexVectorLength: number of rows to evaluate.
  // foreignTableRecordIDVectors / numForeignTables: record ID vectors that
  //   are zipped with the index vector and compacted alongside it
  //   (run() accepts 0..8 foreign tables).
  // functorType: which predicate functor to apply.
  // cudaStream: opaque stream handle, reinterpreted as cudaStream_t.
  FilterContext(
      uint8_t *predicateVector, int indexVectorLength,
      RecordID **foreignTableRecordIDVectors,
      int numForeignTables, FunctorType functorType,
      void *cudaStream)
      : predicateVector(predicateVector),
        indexVectorLength(indexVectorLength),
        foreignTableRecordIDVectors(foreignTableRecordIDVectors),
        numForeignTables(numForeignTables),
        functorType(functorType),
        cudaStream(reinterpret_cast<cudaStream_t>(cudaStream)) {}

  cudaStream_t getStream() const {
    return cudaStream;
  }

  // Unary filter: computes the predicate for each of the first
  // indexVectorLength values of inputIterator, then removes entries from
  // indexVector (and the zipped foreign-table record IDs) in place according
  // to those predicate values. Returns the number of entries kept.
  template<typename InputIterator>
  int run(uint32_t *indexVector, InputIterator inputIterator);

  // Binary filter: same as the unary overload, with the predicate evaluated
  // on (lhsIter[i], rhsIter[i]).
  template<typename LHSIterator, typename RHSIterator>
  int run(uint32_t *indexVector, LHSIterator lhsIter, RHSIterator rhsIter);

 private:
  uint8_t *predicateVector;
  int indexVectorLength;
  RecordID **foreignTableRecordIDVectors;
  int numForeignTables;
  FunctorType functorType;
  cudaStream_t cudaStream;

  // Binary worker: fills predicateVector, then compacts indexZipIterator.
  template<typename LHSIterator, typename RHSIterator,
      typename IndexZipIterator>
  int executeRemoveIf(LHSIterator lhsIter,
                      RHSIterator rhsIter,
                      IndexZipIterator indexZipIterator);

  // Unary worker: fills predicateVector, then compacts indexZipIterator.
  template<typename InputIterator, typename IndexZipIterator>
  int executeRemoveIf(InputIterator inputIter,
                      IndexZipIterator indexZipIterator);
};

}  // namespace ares

// UnaryFilter applies a unary predicate to `input` and compacts indexVector
// (plus any foreign-table record ID vectors) in place.
//
// On success resHandle.res carries the int returned by binder.bind()
// (presumably the surviving row count from FilterContext::run), packed into
// a pointer. On failure resHandle.pStrErr holds a strdup'ed message.
// No exception is allowed to escape, since the handle-style return suggests
// this is called across a C (cgo) boundary.
CGoCallResHandle UnaryFilter(InputVector input,
                             uint32_t *indexVector,
                             uint8_t *predicateVector,
                             int indexVectorLength,
                             RecordID **foreignTableRecordIDVectors,
                             int numForeignTables,
                             uint32_t *baseCounts,
                             uint32_t startCount,
                             UnaryFunctorType functorType,
                             void *cudaStream,
                             int device) {
  CGoCallResHandle resHandle = {nullptr, nullptr};
  try {
#ifdef RUN_ON_DEVICE
    cudaSetDevice(device);
#endif
    ares::FilterContext<UnaryFunctorType> ctx(predicateVector,
                                              indexVectorLength,
                                              foreignTableRecordIDVectors,
                                              numForeignTables,
                                              functorType,
                                              cudaStream);
    std::vector<InputVector> inputVectors = {input};
    ares::InputVectorBinder<ares::FilterContext<UnaryFunctorType>, 1>
        binder(ctx, inputVectors, indexVector, baseCounts, startCount);
    resHandle.res =
        reinterpret_cast<void *>(binder.bind());
    CheckCUDAError("UnaryFilter");
  } catch (const std::exception &e) {
    std::cerr << "Exception happend when doing UnaryFilter:" << e.what()
              << std::endl;
    resHandle.pStrErr = strdup(e.what());
  } catch (...) {
    // Non-std exceptions must not unwind into the C caller either.
    std::cerr << "Unknown exception happened when doing UnaryFilter"
              << std::endl;
    resHandle.pStrErr = strdup("unknown exception in UnaryFilter");
  }
  return resHandle;
}

// BinaryFilter applies a binary predicate to (lhs, rhs) and compacts
// indexVector (plus any foreign-table record ID vectors) in place.
// Result/error reporting is identical to UnaryFilter.
CGoCallResHandle BinaryFilter(InputVector lhs,
                              InputVector rhs,
                              uint32_t *indexVector,
                              uint8_t *predicateVector,
                              int indexVectorLength,
                              RecordID **foreignTableRecordIDVectors,
                              int numForeignTables,
                              uint32_t *baseCounts,
                              uint32_t startCount,
                              BinaryFunctorType functorType,
                              void *cudaStream,
                              int device) {
  CGoCallResHandle resHandle = {nullptr, nullptr};
  try {
#ifdef RUN_ON_DEVICE
    cudaSetDevice(device);
#endif
    ares::FilterContext<BinaryFunctorType> ctx(predicateVector,
                                               indexVectorLength,
                                               foreignTableRecordIDVectors,
                                               numForeignTables,
                                               functorType,
                                               cudaStream);
    std::vector<InputVector> inputVectors = {lhs, rhs};
    ares::InputVectorBinder<ares::FilterContext<BinaryFunctorType>, 2> binder(
        ctx, inputVectors, indexVector, baseCounts, startCount);

    resHandle.res =
        reinterpret_cast<void *>(binder.bind());
    CheckCUDAError("BinaryFilter");
  } catch (const std::exception &e) {
    std::cerr << "Exception happend when doing BinaryFilter:" << e.what()
              << std::endl;
    resHandle.pStrErr = strdup(e.what());
  } catch (...) {
    // Non-std exceptions must not unwind into the C caller either.
    std::cerr << "Unknown exception happened when doing BinaryFilter"
              << std::endl;
    resHandle.pStrErr = strdup("unknown exception in BinaryFilter");
  }
  return resHandle;
}

namespace ares {

// Filter template function for unary transform filter.
// Unary worker:
// 1. Writes UnaryPredicateFunctor(input[i]) into predicateVector[i] for the
//    first indexVectorLength rows.
// 2. Runs thrust::remove_if over the zipped index/record-ID iterator with a
//    RemoveFilter built on predicateVector. RemoveFilter presumably looks up
//    each row's predicate byte; see its definition.
// Returns the number of entries kept. thrust::remove_if is stable, so
// survivors keep their relative order.
template<typename FunctorType>
template<typename InputIterator, typename IndexZipIterator>
int FilterContext<FunctorType>::executeRemoveIf(
    InputIterator inputIter,
    IndexZipIterator indexZipIterator) {
  typedef typename InputIterator::value_type::head_type InputValueType;
  UnaryPredicateFunctor<bool, InputValueType> f(functorType);
  RemoveFilter<typename IndexZipIterator::value_type, uint8_t> removeFilter(
      predicateVector);
  // First compute the predicate values.
  thrust::transform(GET_EXECUTION_POLICY(cudaStream), inputIter,
                    inputIter + indexVectorLength, predicateVector, f);
  // Then use the predicate values to remove indexes in place.
  return thrust::remove_if(GET_EXECUTION_POLICY(cudaStream), indexZipIterator,
                           indexZipIterator + indexVectorLength, removeFilter) -
         indexZipIterator;
}

// Runs the unary filter. Builds the index zip iterator whose arity matches
// numForeignTables (0..8) and forwards to executeRemoveIf.
// Throws std::invalid_argument for any other numForeignTables.
template<typename FunctorType>
template<typename InputIterator>
int FilterContext<FunctorType>::run(uint32_t *indexVector,
                                    InputIterator inputIterator) {
  switch (numForeignTables) {
#define EXECUTE_UNARY_REMOVE_IF(NumTotalForeignTables)                 \
    case NumTotalForeignTables: {                                      \
      IndexZipIteratorMaker<NumTotalForeignTables> maker;              \
      return executeRemoveIf(inputIterator,                            \
                             maker.make(indexVector,                   \
                                        foreignTableRecordIDVectors)); \
    }

    EXECUTE_UNARY_REMOVE_IF(0)
    EXECUTE_UNARY_REMOVE_IF(1)
    EXECUTE_UNARY_REMOVE_IF(2)
    EXECUTE_UNARY_REMOVE_IF(3)
    EXECUTE_UNARY_REMOVE_IF(4)
    EXECUTE_UNARY_REMOVE_IF(5)
    EXECUTE_UNARY_REMOVE_IF(6)
    EXECUTE_UNARY_REMOVE_IF(7)
    EXECUTE_UNARY_REMOVE_IF(8)
// Keep the helper macro local to this switch.
#undef EXECUTE_UNARY_REMOVE_IF
    default:
      throw std::invalid_argument("only support up to 8 foreign tables");
  }
}

// Filter template function for binary transform filter.
// Binary worker. The predicate is evaluated on the common_type of both
// operands' head types:
// 1. Stores BinaryPredicateFunctor(lhs[i], rhs[i]) into predicateVector[i]
//    for the first indexVectorLength rows.
// 2. Compacts the zipped index/record-ID iterator in place with
//    thrust::remove_if and a RemoveFilter bound to predicateVector.
// Returns how many entries survived.
template<typename FunctorType>
template<typename LHSIterator, typename RHSIterator, typename IndexZipIterator>
int FilterContext<FunctorType>::executeRemoveIf(
    LHSIterator lhsIter,
    RHSIterator rhsIter,
    IndexZipIterator indexZipIterator) {
  using LHSValue = typename LHSIterator::value_type::head_type;
  using RHSValue = typename RHSIterator::value_type::head_type;
  using InputValueType = typename common_type<LHSValue, RHSValue>::type;
  using IndexValue = typename IndexZipIterator::value_type;

  BinaryPredicateFunctor<bool, InputValueType> predicate(functorType);
  RemoveFilter<IndexValue, uint8_t> removeFilter(predicateVector);

  // Stage 1: materialize one predicate value per row.
  auto const lhsEnd = lhsIter + indexVectorLength;
  thrust::transform(GET_EXECUTION_POLICY(cudaStream),
                    lhsIter, lhsEnd, rhsIter, predicateVector, predicate);

  // Stage 2: drop indexes in place based on the stored predicate values.
  auto const indexEnd = indexZipIterator + indexVectorLength;
  auto const keptEnd = thrust::remove_if(GET_EXECUTION_POLICY(cudaStream),
                                         indexZipIterator, indexEnd,
                                         removeFilter);
  return keptEnd - indexZipIterator;
}

// Runs the binary filter (dispatch on the number of foreign tables).
// Runs the binary filter. Builds the index zip iterator whose arity matches
// numForeignTables (0..8) and forwards to the binary executeRemoveIf.
// Returns the number of entries kept.
// Throws std::invalid_argument for any other numForeignTables.
template<typename FunctorType>
template<typename LHSIterator, typename RHSIterator>
int FilterContext<FunctorType>::run(uint32_t *indexVector,
                                    LHSIterator lhsIter,
                                    RHSIterator rhsIter) {
  switch (numForeignTables) {
#define EXECUTE_BINARY_REMOVE_IF(NumTotalForeignTables)                \
    case NumTotalForeignTables: {                                      \
      IndexZipIteratorMaker<NumTotalForeignTables> maker;              \
      return executeRemoveIf(lhsIter, rhsIter,                         \
                             maker.make(indexVector,                   \
                                        foreignTableRecordIDVectors)); \
    }

    EXECUTE_BINARY_REMOVE_IF(0)
    EXECUTE_BINARY_REMOVE_IF(1)
    EXECUTE_BINARY_REMOVE_IF(2)
    EXECUTE_BINARY_REMOVE_IF(3)
    EXECUTE_BINARY_REMOVE_IF(4)
    EXECUTE_BINARY_REMOVE_IF(5)
    EXECUTE_BINARY_REMOVE_IF(6)
    EXECUTE_BINARY_REMOVE_IF(7)
    EXECUTE_BINARY_REMOVE_IF(8)
// Keep the helper macro local to this switch.
#undef EXECUTE_BINARY_REMOVE_IF
    default:
      throw std::invalid_argument("only support up to 8 foreign tables");
  }
}

}  // namespace ares