// Source: github.com/rohankumardubey/aresdb@v0.0.2-0.20190517170215-e54e3ca06b9c/query/hash_lookup.cu
// Copyright (c) 2017-2018 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstring>
#include <algorithm>
#include <exception>
#include <vector>
#include "query/transform.hpp"
#include "query/binder.hpp"

namespace ares {

// State needed to perform one hash lookup pass: how many input rows to
// look up, the cuckoo hash index to probe, where to write the resulting
// RecordIDs, and the CUDA stream the work is issued on.
class HashLookupContext {
 private:
  // Number of input elements run() transforms.
  int indexVectorLength;
  // Hash index whose buckets/seeds/keyBytes/numHashes/numBuckets feed
  // the HashLookupFunctor built in run().
  CuckooHashIndex hashIndex;
  // Output buffer receiving one RecordID per looked-up element.
  RecordID* recordIDVector;
  // Stream handed to the Thrust execution policy in run().
  cudaStream_t cudaStream;

 public:
  // The stream arrives type-erased as void* (the C entry point passes it
  // through untouched) and is cast back to cudaStream_t here.
  HashLookupContext(int indexVectorLength,
                    CuckooHashIndex hashIndex,
                    RecordID* recordIDVector,
                    void *cudaStream)
      : indexVectorLength(indexVectorLength),
        hashIndex(hashIndex),
        recordIDVector(recordIDVector),
        cudaStream(reinterpret_cast<cudaStream_t>(cudaStream)) {}

  // Returns the stream supplied at construction.
  cudaStream_t getStream() const { return cudaStream; }

  // Looks up the first indexVectorLength values produced by inputIterator
  // in hashIndex, writes the RecordIDs to recordIDVector and returns how
  // many were written.
  template<typename InputIterator>
  int run(uint32_t *indexVector, InputIterator inputIterator);
};

// Specialized for HashLookupContext.
template <>
class InputVectorBinder<HashLookupContext, 1> : public InputVectorBinderBase<
    HashLookupContext, 1, 1> {
  typedef InputVectorBinderBase<HashLookupContext, 1, 1> super_t;
 public:
  explicit InputVectorBinder(HashLookupContext context,
                             std::vector<InputVector> inputVectors,
                             uint32_t *indexVector, uint32_t *baseCounts,
                             uint32_t startCount) : super_t(context,
                                                            inputVectors,
                                                            indexVector,
                                                            baseCounts,
                                                            startCount) {
  }
 public:
  // Binds the single input vector to an iterator and forwards it (together
  // with any already-bound iterators) to the next binding stage. Defined
  // below so that UUID inputs get dedicated handling.
  template<typename ...InputIterators>
  int bind(InputIterators... boundInputIterators);
};

}  // namespace ares

// C entry point: looks up every value of `input` in `hashIndex` and writes
// the resulting RecordIDs into `output` (HashLookupContext::run writes
// indexVectorLength entries starting at `output`).
//
// indexVector, baseCounts and startCount are forwarded to the input binder,
// which uses them to build the input iterator. When RUN_ON_DEVICE is
// defined, `device` is made current before any work is issued; `cudaStream`
// is the stream the lookup runs on.
//
// On success `res` carries the int returned by the binder (the number of
// RecordIDs written). If a std::exception escapes, its message is logged to
// stderr and a strdup'd (heap-allocated) copy is returned in `pStrErr`.
CGoCallResHandle HashLookup(InputVector input, RecordID *output,
                            uint32_t *indexVector, int indexVectorLength,
                            uint32_t *baseCounts, uint32_t startCount,
                            CuckooHashIndex hashIndex, void *cudaStream,
                            int device) {
  CGoCallResHandle resHandle = {nullptr, nullptr};
  try {
#ifdef RUN_ON_DEVICE
    cudaSetDevice(device);
    // Fail fast if the device switch failed (e.g. invalid ordinal).
    // Otherwise the lookup would be issued on whichever device happens to
    // be current, against buffers that may live on `device`. The runtime
    // records the cudaSetDevice failure as the pending last error, which
    // CheckCUDAError is relied on to report, as it is after bind() below.
    CheckCUDAError("HashLookup");
#endif
    ares::HashLookupContext ctx(indexVectorLength,
                                hashIndex, output, cudaStream);
    std::vector<InputVector> inputVectors = {input};
    ares::InputVectorBinder<ares::HashLookupContext, 1>
        binder(ctx, inputVectors, indexVector, baseCounts, startCount);
    // The count is smuggled through the void* result slot of the handle.
    resHandle.res =
        reinterpret_cast<void *>(binder.bind());
    CheckCUDAError("HashLookup");
  } catch (const std::exception &e) {
    std::cerr << "Exception happened when doing HashLookup:" << e.what()
              << std::endl;
    resHandle.pStrErr = strdup(e.what());
  }
  return resHandle;
}


namespace ares {

// Specialized version for hash lookup to support UUID.
//
// A VectorPartyInput whose DataType is UUID is bound to an iterator over
// thrust::tuple<UUIDT, bool>:
//   - BasePtr == nullptr (mode 0): a constant iterator of
//     (DefaultValue.Value.UUIDVal, DefaultValue.HasDefault);
//   - otherwise: a column iterator over the vector party's values/nulls.
// Every other input is delegated unchanged to super_t::bind.
template<typename ...InputIterators>
int InputVectorBinder<HashLookupContext, 1>::bind(
    InputIterators... boundInputIterators) {
  InputVector input = super_t::inputVectors[0];

  // Short-circuit keeps input.Vector.VP from being read unless the input
  // really is a vector party.
  if (input.Type != VectorPartyInput || input.Vector.VP.DataType != UUID) {
    return super_t::bind(boundInputIterators...);
  }

  uint32_t *indexVector = super_t::indexVector;
  uint32_t *baseCounts = super_t::baseCounts;
  uint32_t startCount = super_t::startCount;
  // Next binding stage; it receives the UUID iterator appended to the
  // already-bound iterators (template arg 0 presumably means "no inputs
  // left to bind" -- confirm against binder.hpp). Built only on the UUID
  // path so other inputs don't pay for copying the input vector list.
  InputVectorBinderBase<HashLookupContext, 1, 0>
      nextBinder(context, inputVectors, indexVector, baseCounts, startCount);

  VectorPartySlice inputVP = input.Vector.VP;
  uint8_t *basePtr = inputVP.BasePtr;
  // Treat mode 0 as constant vector.
  if (basePtr == nullptr) {
    bool hasDefault = inputVP.DefaultValue.HasDefault;
    DefaultValue defaultValue = inputVP.DefaultValue;
    return nextBinder.bind(boundInputIterators...,
                           thrust::make_constant_iterator(
                               thrust::make_tuple<UUIDT, bool>(
                                   defaultValue.Value.UUIDVal,
                                   hasDefault)));
  }

  uint32_t nullsOffset = inputVP.NullsOffset;
  uint32_t valuesOffset = inputVP.ValuesOffset;
  uint8_t startingIndex = inputVP.StartingIndex;
  uint8_t stepInBytes = getStepInBytes(inputVP.DataType);
  uint32_t length = inputVP.Length;
  return nextBinder.bind(boundInputIterators...,
                         make_column_iterator<UUIDT>(indexVector,
                                                     baseCounts,
                                                     startCount,
                                                     basePtr,
                                                     nullsOffset,
                                                     valuesOffset,
                                                     length,
                                                     stepInBytes,
                                                     startingIndex));
}

// Applies a HashLookupFunctor built from hashIndex to the first
// indexVectorLength elements of inputIter and writes the results to
// recordIDVector, using the execution policy GET_EXECUTION_POLICY derives
// from cudaStream. The functor is instantiated for the first tuple element
// type of the iterator's value_type (e.g. UUIDT for UUID inputs).
// Returns the number of RecordIDs written.
//
// indexVector is not read here; presumably the bound input iterator already
// applies it (make_column_iterator receives it) -- confirm before relying on
// this.
template<typename InputIterator>
int HashLookupContext::run(uint32_t *indexVector, InputIterator inputIter) {
  typedef typename InputIterator::value_type::head_type InputValueType;
  HashLookupFunctor<InputValueType> f(hashIndex.buckets, hashIndex.seeds,
                                      hashIndex.keyBytes, hashIndex.numHashes,
                                      hashIndex.numBuckets);
  RecordID *outputEnd = thrust::transform(GET_EXECUTION_POLICY(cudaStream),
                                          inputIter,
                                          inputIter + indexVectorLength,
                                          recordIDVector, f);
  // The element count fits in int because the input range length is an int.
  return static_cast<int>(outputEnd - recordIDVector);
}

}  // namespace ares