github.com/rohankumardubey/aresdb@v0.0.2-0.20190517170215-e54e3ca06b9c/query/hash_lookup.cu (about)

     1  //  Copyright (c) 2017-2018 Uber Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  #include <cstring>
    16  #include <algorithm>
    17  #include <exception>
    18  #include <vector>
    19  #include "query/transform.hpp"
    20  #include "query/binder.hpp"
    21  
    22  namespace ares {
    23  
    24  class HashLookupContext {
    25   public:
    26    HashLookupContext(int indexVectorLength, CuckooHashIndex hashIndex,
    27                      RecordID* recordIDVector,
    28                      void *cudaStream)
    29        : indexVectorLength(indexVectorLength),
    30          hashIndex(hashIndex),
    31          recordIDVector(recordIDVector),
    32          cudaStream(reinterpret_cast<cudaStream_t>(cudaStream)) {}
    33  
    34    template<typename InputIterator>
    35    int run(uint32_t *indexVector, InputIterator inputIterator);
    36  
    37    cudaStream_t getStream() const {
    38      return cudaStream;
    39    }
    40  
    41   private:
    42    int indexVectorLength;
    43    CuckooHashIndex hashIndex;
    44    RecordID* recordIDVector;
    45    cudaStream_t cudaStream;
    46  };
    47  
    48  // Specialized for HashLookupContext.
    49  template <>
    50  class InputVectorBinder<HashLookupContext, 1> : public InputVectorBinderBase<
    51      HashLookupContext, 1, 1> {
    52    typedef InputVectorBinderBase<HashLookupContext, 1, 1> super_t;
    53   public:
    54    explicit InputVectorBinder(HashLookupContext context,
    55                               std::vector<InputVector> inputVectors,
    56                               uint32_t *indexVector, uint32_t *baseCounts,
    57                               uint32_t startCount) : super_t(context,
    58                                                              inputVectors,
    59                                                              indexVector,
    60                                                              baseCounts,
    61                                                              startCount) {
    62    }
    63   public:
    64    template<typename ...InputIterators>
    65    int bind(InputIterators... boundInputIterators);
    66  };
    67  
    68  }  // namespace ares
    69  
    70  CGoCallResHandle HashLookup(InputVector input, RecordID *output,
    71                              uint32_t *indexVector, int indexVectorLength,
    72                              uint32_t *baseCounts, uint32_t startCount,
    73                              CuckooHashIndex hashIndex, void *cudaStream,
    74                              int device) {
    75    CGoCallResHandle resHandle = {nullptr, nullptr};
    76    try {
    77  #ifdef RUN_ON_DEVICE
    78      cudaSetDevice(device);
    79  #endif
    80      ares::HashLookupContext ctx(indexVectorLength,
    81                                  hashIndex, output, cudaStream);
    82      std::vector<InputVector> inputVectors = {input};
    83      ares::InputVectorBinder<ares::HashLookupContext, 1>
    84          binder(ctx, inputVectors, indexVector, baseCounts, startCount);
    85      resHandle.res =
    86          reinterpret_cast<void *>(binder.bind());
    87      CheckCUDAError("HashLookup");
    88    } catch (std::exception &e) {
    89      std::cerr << "Exception happened when doing HashLookup:" << e.what()
    90                << std::endl;
    91      resHandle.pStrErr = strdup(e.what());
    92    }
    93    return resHandle;
    94  }
    95  
    96  
    97  namespace ares {
    98  
    99  // Specialized version for hash lookup to support
   100  // UUID.
   101  template<typename ...InputIterators>
   102  int InputVectorBinder<HashLookupContext, 1>::bind(
   103      InputIterators... boundInputIterators) {
   104    InputVector input = super_t::inputVectors[0];
   105    uint32_t *indexVector = super_t::indexVector;
   106    uint32_t *baseCounts = super_t::baseCounts;
   107    uint32_t startCount = super_t::startCount;
   108  
   109    if (input.Type == VectorPartyInput) {
   110      InputVectorBinderBase<HashLookupContext, 1, 0>
   111          nextBinder(context, inputVectors, indexVector, baseCounts, startCount);
   112      VectorPartySlice inputVP = input.Vector.VP;
   113      if (inputVP.DataType == UUID) {
   114        uint8_t *basePtr = inputVP.BasePtr;
   115        // Treat mode 0 as constant vector.
   116        if (basePtr == nullptr) {
   117          bool hasDefault = inputVP.DefaultValue.HasDefault;
   118          DefaultValue defaultValue = inputVP.DefaultValue;
   119          return nextBinder.bind(boundInputIterators...,
   120                                 thrust::make_constant_iterator(
   121                                     thrust::make_tuple<
   122                                         UUIDT,
   123                                         bool>(
   124                                         defaultValue.Value.UUIDVal,
   125                                         hasDefault)));
   126        }
   127  
   128        uint32_t nullsOffset = inputVP.NullsOffset;
   129        uint32_t valuesOffset = inputVP.ValuesOffset;
   130        uint8_t startingIndex = inputVP.StartingIndex;
   131        uint8_t stepInBytes = getStepInBytes(inputVP.DataType);
   132        uint32_t length = inputVP.Length;
   133        return nextBinder.bind(boundInputIterators...,
   134                                      make_column_iterator<UUIDT>(indexVector,
   135                                                                  baseCounts,
   136                                                                  startCount,
   137                                                                  basePtr,
   138                                                                  nullsOffset,
   139                                                                  valuesOffset,
   140                                                                  length,
   141                                                                  stepInBytes,
   142                                                                  startingIndex));
   143      }
   144    }
   145    return super_t::bind(boundInputIterators...);
   146  }
   147  
   148  template<typename InputIterator>
   149  int HashLookupContext::run(uint32_t *indexVector, InputIterator inputIter) {
   150    typedef typename InputIterator::value_type::head_type InputValueType;
   151    HashLookupFunctor<InputValueType> f(hashIndex.buckets, hashIndex.seeds,
   152                                        hashIndex.keyBytes, hashIndex.numHashes,
   153                                        hashIndex.numBuckets);
   154    return thrust::transform(GET_EXECUTION_POLICY(cudaStream), inputIter,
   155        inputIter + indexVectorLength, recordIDVector, f) -
   156    recordIDVector;
   157  }
   158  
   159  }  // namespace ares