github.com/rohankumardubey/aresdb@v0.0.2-0.20190517170215-e54e3ca06b9c/query/binder.hpp (about) 1 // Copyright (c) 2017-2018 Uber Technologies, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef QUERY_BINDER_HPP_ 16 #define QUERY_BINDER_HPP_ 17 #include <cuda_runtime.h> 18 #include <cfloat> 19 #include <cstdint> 20 #include <tuple> 21 #include <type_traits> 22 #include <vector> 23 #include "query/algorithm.hpp" 24 #include "query/functor.hpp" 25 #include "query/iterator.hpp" 26 #include "query/memory.hpp" 27 #include "query/time_series_aggregate.h" 28 29 namespace ares { 30 31 template<typename Value> 32 ForeignTableIterator<Value> *prepareForeignTableIterators( 33 int32_t numBatches, 34 VectorPartySlice *vpSlices, 35 size_t stepBytes, 36 bool hasDefault, 37 Value defaultValue, cudaStream_t stream); 38 39 // Forward declaration; 40 template<typename Context, int NumVectors, int NumUnboundIterators> 41 class InputVectorBinderBase; 42 43 // InputVectorBinder will bind NumVectors InputVector struct to different type 44 // of iterators. When there is no more input vector to bind, it will call 45 // context's run function with all the bound iterator. 46 // A example usage is like: 47 // ares::FilterContext<UnaryFunctorType> ctx(predicateVector, 48 // indexVectorLength, 49 // foreignTableRecordIDVectors, 50 // numForeignTables, 51 // functorType, 52 // cudaStream); 53 // std::vector<InputVector> inputVectors = {input}; 54 // ares::InputVectorBinder<ares::FilterContext<UnaryFunctorType>, 1> 55 // binder(ctx, inputVectors, indexVector, baseCounts, startCount); 56 // resHandle.res = 57 // reinterpret_cast<void *>(binder.bind()); 58 // 59 // If the caller need to specialize the binder with a context, they will always 60 template<typename Context, int NumVectors> 61 class InputVectorBinder : public InputVectorBinderBase<Context, 62 NumVectors, 63 NumVectors> { 64 typedef InputVectorBinderBase<Context, NumVectors, NumVectors> super_t; 65 public: 66 explicit InputVectorBinder(Context context, 67 std::vector<InputVector> inputVectors, 68 uint32_t *indexVector, uint32_t *baseCounts, 69 uint32_t startCount) : super_t(context, 70 inputVectors, 71 indexVector, 72 baseCounts, 73 startCount) { 74 } 75 }; 76 77 // InputIteratorBinderBase is the class to bind InputVector structs into 78 // individual input iterators. It will bind one input vector at one time until 79 // N becomes zero. 80 template<typename Context, int NumVectors, int NumUnboundIterators> 81 class InputVectorBinderBase { 82 public: 83 explicit InputVectorBinderBase(Context context, 84 std::vector<InputVector> inputVectors, 85 uint32_t *indexVector, 86 uint32_t *baseCounts, 87 uint32_t startCount) : 88 context(context), 89 inputVectors(inputVectors), 90 indexVector(indexVector), 91 baseCounts(baseCounts), 92 startCount(startCount) {} 93 94 protected: 95 Context context; 96 std::vector<InputVector> inputVectors; 97 uint32_t *indexVector; 98 uint32_t *baseCounts; 99 uint32_t startCount; 100 101 template<typename ...InputIterators> 102 int bindGeneric(InputIterators... boundInputIterators) { 103 InputVectorBinderBase<Context, NumVectors, NumUnboundIterators - 1> 104 nextBinder(context, inputVectors, indexVector, baseCounts, startCount); 105 106 InputVector input = inputVectors[NumVectors - NumUnboundIterators]; 107 108 #define BIND_CONSTANT_INPUT(defaultValue, isValid) \ 109 return nextBinder.bind( \ 110 boundInputIterators..., \ 111 make_constant_iterator( \ 112 defaultValue, isValid)); 113 114 if (input.Type == ConstantInput) { 115 ConstantVector constant = input.Vector.Constant; 116 if (constant.DataType == ConstInt) { 117 BIND_CONSTANT_INPUT(constant.Value.IntVal, constant.IsValid) 118 } else if (constant.DataType == ConstFloat) { 119 BIND_CONSTANT_INPUT(constant.Value.FloatVal, constant.IsValid) 120 } 121 } else if (input.Type == ScratchSpaceInput) { 122 ScratchSpaceVector scratchSpace = input.Vector.ScratchSpace; 123 uint32_t nullsOffset = scratchSpace.NullsOffset; 124 125 #define BIND_SCRATCH_SPACE_INPUT(dataType) \ 126 return nextBinder.bind( \ 127 boundInputIterators..., \ 128 make_scratch_space_input_iterator<dataType>( \ 129 scratchSpace.Values, \ 130 nullsOffset)); 131 132 switch (scratchSpace.DataType) { 133 case Int32: 134 BIND_SCRATCH_SPACE_INPUT(int32_t) 135 case Uint32: 136 BIND_SCRATCH_SPACE_INPUT(uint32_t) 137 case Float32: 138 BIND_SCRATCH_SPACE_INPUT(float_t) 139 default: 140 throw std::invalid_argument( 141 "Unsupported data type for ScratchSpaceInput"); 142 } 143 } else if (input.Type == ForeignColumnInput) { 144 // Note: for now foreign vectors are dimension table columns 145 // that are not compressed nor pre sliced 146 RecordID *recordIDs = input.Vector.ForeignVP.RecordIDs; 147 const int32_t numBatches = input.Vector.ForeignVP.NumBatches; 148 const int32_t baseBatchID = input.Vector.ForeignVP.BaseBatchID; 149 VectorPartySlice *vpSlices = input.Vector.ForeignVP.Batches; 150 const int32_t numRecordsInLastBatch = 151 input.Vector.ForeignVP.NumRecordsInLastBatch; 152 int16_t *const timezoneLookup = input.Vector.ForeignVP.TimezoneLookup; 153 int16_t timezoneLookupSize = input.Vector.ForeignVP.TimezoneLookupSize; 154 DataType dataType = input.Vector.ForeignVP.DataType; 155 bool hasDefault = input.Vector.ForeignVP.DefaultValue.HasDefault; 156 DefaultValue defaultValueStruct = input.Vector.ForeignVP.DefaultValue; 157 uint8_t stepInBytes = getStepInBytes(dataType); 158 159 switch (dataType) { 160 #define BIND_FOREIGN_COLUMN_INPUT(defaultValue, dataType) \ 161 ForeignTableIterator<dataType> *vpIters = \ 162 prepareForeignTableIterators(numBatches, vpSlices, stepInBytes, \ 163 hasDefault, defaultValue, context.getStream()); \ 164 int res = nextBinder.bind(boundInputIterators..., \ 165 RecordIDJoinIterator<dataType>( \ 166 recordIDs, numBatches, baseBatchID, \ 167 vpIters, numRecordsInLastBatch, \ 168 timezoneLookup, timezoneLookupSize)); \ 169 deviceFree(vpIters); \ 170 return res; 171 172 case Bool: { 173 BIND_FOREIGN_COLUMN_INPUT(defaultValueStruct.Value.BoolVal, bool) 174 } 175 case Int8: 176 case Int16: 177 case Int32: { 178 BIND_FOREIGN_COLUMN_INPUT( 179 defaultValueStruct.Value.Int32Val, int32_t) 180 } 181 case Uint8: 182 case Uint16: 183 case Uint32: { 184 BIND_FOREIGN_COLUMN_INPUT( 185 defaultValueStruct.Value.Uint32Val, uint32_t) 186 } 187 case Float32: { 188 BIND_FOREIGN_COLUMN_INPUT( 189 defaultValueStruct.Value.FloatVal, float_t) 190 } 191 default: 192 throw std::invalid_argument( 193 "Unsupported data type for VectorPartyInput: " + 194 std::to_string(__LINE__)); 195 } 196 } 197 198 VectorPartySlice inputVP = input.Vector.VP; 199 bool hasDefault = inputVP.DefaultValue.HasDefault; 200 bool isConstant = inputVP.BasePtr == nullptr; 201 DefaultValue defaultValue = inputVP.DefaultValue; 202 203 if (isConstant) { 204 switch (inputVP.DataType) { 205 case Bool: 206 BIND_CONSTANT_INPUT(defaultValue.Value.BoolVal, hasDefault) 207 case Int8: 208 case Int16: 209 case Int32: 210 BIND_CONSTANT_INPUT(defaultValue.Value.Int32Val, hasDefault) 211 case Uint8: 212 case Uint16: 213 case Uint32: 214 BIND_CONSTANT_INPUT(defaultValue.Value.Uint32Val, hasDefault) 215 case Float32: 216 BIND_CONSTANT_INPUT(defaultValue.Value.FloatVal, hasDefault) 217 default: 218 throw std::invalid_argument( 219 "Unsupported data type for VectorPartyInput: " + 220 std::to_string(__LINE__)); 221 } 222 } 223 224 // Non constant. 225 uint8_t *basePtr = inputVP.BasePtr; 226 uint32_t nullsOffset = inputVP.NullsOffset; 227 uint32_t valuesOffset = inputVP.ValuesOffset; 228 uint8_t startingIndex = inputVP.StartingIndex; 229 uint8_t stepInBytes = getStepInBytes(inputVP.DataType); 230 uint32_t length = inputVP.Length; 231 switch (inputVP.DataType) { 232 #define BIND_COLUMN_INPUT(dataType) \ 233 return nextBinder.bind(boundInputIterators..., \ 234 make_column_iterator<dataType>(indexVector, \ 235 baseCounts, \ 236 startCount, \ 237 basePtr, \ 238 nullsOffset, \ 239 valuesOffset, \ 240 length, \ 241 stepInBytes, \ 242 startingIndex)); 243 case Bool: 244 BIND_COLUMN_INPUT(bool) 245 case Int8: 246 case Int16: 247 case Int32: 248 BIND_COLUMN_INPUT(int32_t) 249 case Uint8: 250 case Uint16: 251 case Uint32: 252 BIND_COLUMN_INPUT(uint32_t) 253 case Float32: 254 BIND_COLUMN_INPUT(float_t) 255 default: 256 throw std::invalid_argument( 257 "Unsupported data type for VectorPartyInput: " + 258 std::to_string(__LINE__)); 259 } 260 } 261 262 template<typename GeoIterator> 263 int bindGeoPoint(GeoIterator geoIter) { 264 InputVectorBinderBase<Context, NumVectors, NumUnboundIterators - 1> 265 nextBinder(context, inputVectors, indexVector, baseCounts, startCount); 266 267 InputVector input = inputVectors[NumVectors - NumUnboundIterators]; 268 if (input.Type == ConstantInput) { 269 ConstantVector constant = input.Vector.Constant; 270 if (constant.DataType == ConstGeoPoint) { 271 return nextBinder.bind( 272 geoIter, 273 thrust::make_constant_iterator( 274 thrust::make_tuple<GeoPointT, bool>( 275 constant.Value.GeoPointVal, constant.IsValid))); 276 } 277 } 278 throw std::invalid_argument( 279 "Unsupported data type " + std::to_string(__LINE__) 280 + "when value type of first input iterator is GeoPoint"); 281 } 282 283 public: 284 template<typename ...InputIterators> 285 int bind(InputIterators... boundInputIterators) { 286 return bindGeneric(boundInputIterators...); 287 } 288 289 // when this is the first input iterator, we allow geo point iterator and uuid 290 // iterator. 291 int bind() { 292 InputVectorBinderBase<Context, NumVectors, NumUnboundIterators - 1> 293 nextBinder(context, inputVectors, indexVector, baseCounts, startCount); 294 InputVector input = inputVectors[NumVectors - NumUnboundIterators]; 295 if (input.Type == VectorPartyInput) { 296 VectorPartySlice inputVP = input.Vector.VP; 297 DataType dataType = inputVP.DataType; 298 uint8_t *basePtr = inputVP.BasePtr; 299 bool hasDefault = inputVP.DefaultValue.HasDefault; 300 DefaultValue defaultValue = inputVP.DefaultValue; 301 uint32_t nullsOffset = inputVP.NullsOffset; 302 uint32_t valuesOffset = inputVP.ValuesOffset; 303 uint8_t startingIndex = inputVP.StartingIndex; 304 uint8_t stepInBytes = getStepInBytes(inputVP.DataType); 305 uint32_t length = inputVP.Length; 306 // This macro will bind column type with width > 4 bytes (GeoPoint, UUID 307 // int64). Since our scratch space is always 4 bytes (int32, uint32, 308 // float), parent nodes for those wider types must be a root node. 309 310 #define BIND_WIDER_COLUMN_INPUT(dataType, defaultValue) \ 311 if (basePtr == nullptr) { \ 312 return nextBinder.bind(thrust::make_constant_iterator( \ 313 thrust::make_tuple<dataType, bool>( \ 314 defaultValue, hasDefault))); \ 315 } \ 316 return nextBinder.bind(make_column_iterator<dataType>( \ 317 indexVector, baseCounts, startCount, basePtr, nullsOffset, \ 318 valuesOffset, length, stepInBytes, startingIndex)); 319 320 switch (dataType) { 321 case GeoPoint: 322 BIND_WIDER_COLUMN_INPUT(GeoPointT, defaultValue.Value.GeoPointVal) 323 case UUID: 324 BIND_WIDER_COLUMN_INPUT(UUIDT, defaultValue.Value.UUIDVal) 325 case Int64: 326 BIND_WIDER_COLUMN_INPUT(int64_t, defaultValue.Value.Int64Val) 327 default: break; 328 } 329 } else if (input.Type == ForeignColumnInput) { 330 RecordID *recordIDs = input.Vector.ForeignVP.RecordIDs; 331 const int32_t numBatches = input.Vector.ForeignVP.NumBatches; 332 const int32_t baseBatchID = input.Vector.ForeignVP.BaseBatchID; 333 VectorPartySlice *vpSlices = input.Vector.ForeignVP.Batches; 334 const int32_t numRecordsInLastBatch = 335 input.Vector.ForeignVP.NumRecordsInLastBatch; 336 DataType dataType = input.Vector.ForeignVP.DataType; 337 bool hasDefault = input.Vector.ForeignVP.DefaultValue.HasDefault; 338 DefaultValue defaultValueStruct = input.Vector.ForeignVP.DefaultValue; 339 uint8_t stepInBytes = getStepInBytes(dataType); 340 341 #define BIND_WIDER_FOREIGN_COLUMN_INPUT(defaultValue, dataType) { \ 342 ForeignTableIterator<dataType> *vpIters = \ 343 prepareForeignTableIterators(numBatches, vpSlices, stepInBytes, \ 344 hasDefault, defaultValue, context.getStream()); \ 345 int res = nextBinder.bind(RecordIDJoinIterator<dataType>( \ 346 recordIDs, numBatches, baseBatchID, \ 347 vpIters, numRecordsInLastBatch, \ 348 nullptr, 0)); \ 349 deviceFree(vpIters); \ 350 return res; \ 351 } 352 353 switch (dataType) { 354 case UUID: BIND_WIDER_FOREIGN_COLUMN_INPUT( 355 defaultValueStruct.Value.UUIDVal, UUIDT) 356 case Int64: BIND_WIDER_FOREIGN_COLUMN_INPUT( 357 defaultValueStruct.Value.Int64Val, int64_t) 358 default: break; 359 } 360 } 361 return bindGeneric(); 362 } 363 364 // UUID data type is only supported in UnaryTransform 365 template <typename UUIDIterator> 366 typename std::enable_if< 367 std::is_same<typename UUIDIterator::value_type::head_type, UUIDT>::value, 368 int>::type 369 bind(UUIDIterator uuidIter) { 370 throw std::invalid_argument( 371 "UUID data type is only supported in UnaryTransform" + 372 std::to_string(__LINE__)); 373 } 374 375 // Int64 data type is only supported in UnaryTransform 376 template <typename Int64Iterator> 377 typename std::enable_if< 378 std::is_same<typename Int64Iterator::value_type::head_type, 379 int64_t>::value, 380 int>::type 381 bind(Int64Iterator int64Iter) { 382 throw std::invalid_argument( 383 "int64 data type is only supported in UnaryTransform" + 384 std::to_string(__LINE__)); 385 } 386 387 // Special handling if the first input iter is a geo iter. 388 template<typename GeoIterator> 389 typename std::enable_if< 390 std::is_same<typename GeoIterator::value_type::head_type, 391 GeoPointT>::value, int>::type bind( 392 GeoIterator geoIter) { 393 return bindGeoPoint(geoIter); 394 } 395 }; 396 397 // This class is called when there is no more unbound iterators. It will just 398 // call context.run to do actual calculation. 399 template<typename Context, int NumVectors> 400 class InputVectorBinderBase<Context, NumVectors, 0> { 401 public: 402 explicit InputVectorBinderBase(Context context, 403 std::vector<InputVector> inputVectors, 404 uint32_t *indexVector, uint32_t *baseCounts, 405 uint32_t startCount) : 406 context(context), 407 inputVectors(inputVectors), 408 indexVector(indexVector), 409 baseCounts(baseCounts), 410 startCount(startCount) {} 411 412 protected: 413 Context context; 414 std::vector<InputVector> inputVectors; 415 uint32_t *indexVector; 416 uint32_t *baseCounts; 417 uint32_t startCount; 418 419 public: 420 template<typename ...InputIterators> 421 int bind(InputIterators... boundInputIterators) { 422 return context.run(indexVector, boundInputIterators...); 423 } 424 }; 425 426 template<typename Value> 427 ForeignTableIterator<Value> *prepareForeignTableIterators( 428 int32_t numBatches, 429 VectorPartySlice *vpSlices, 430 size_t stepBytes, 431 bool hasDefault, 432 Value defaultValue, cudaStream_t stream) { 433 typedef ForeignTableIterator<Value> ValueIter; 434 int totalSize = sizeof(ValueIter) * numBatches; 435 ValueIter* batches; 436 std::vector<ValueIter> batchesH(numBatches); 437 for (int i = 0; i < numBatches; i++) { 438 VectorPartySlice inputVP = vpSlices[i]; 439 if (inputVP.BasePtr == nullptr) { 440 batchesH[i] = ValueIter( 441 make_constant_iterator(defaultValue, 442 hasDefault)); 443 } else { 444 batchesH[i] = 445 ValueIter( 446 VectorPartyIterator<Value>( 447 nullptr, 448 0, 449 inputVP.BasePtr, 450 inputVP.NullsOffset, 451 inputVP.ValuesOffset, 452 inputVP.Length, 453 stepBytes, 454 inputVP.StartingIndex)); 455 } 456 } 457 458 // In host mode we actually don't need to make another copy 459 // of the iterators but we still do here to keep the same code 460 // for host and device mode. 461 ValueIter *vpItersDevice; 462 deviceMalloc(reinterpret_cast<void **>( 463 &vpItersDevice), totalSize); 464 ares::asyncCopyHostToDevice(reinterpret_cast<void *>(vpItersDevice), 465 reinterpret_cast<void *>(&batchesH[0]), totalSize, stream); 466 batches = vpItersDevice; 467 return batches; 468 } 469 470 // IndexZipIteratorMapper is the mapper to map number of foreign tables to 471 // the actual zip iterator 472 template<int NumTotalForeignTables> 473 struct IndexZipIteratorMapper { 474 typedef thrust::zip_iterator< 475 thrust::tuple < thrust::counting_iterator 476 < uint32_t>, uint32_t*>> type; 477 }; 478 479 template<> 480 struct IndexZipIteratorMapper<1> { 481 typedef thrust::zip_iterator< 482 thrust::tuple < thrust::counting_iterator 483 < uint32_t>, uint32_t*, RecordID*>> type; 484 }; 485 486 template<> 487 struct IndexZipIteratorMapper<2> { 488 typedef thrust::zip_iterator< 489 thrust::tuple < thrust::counting_iterator 490 < uint32_t>, uint32_t*, RecordID*, RecordID*>> type; 491 }; 492 493 template<> 494 struct IndexZipIteratorMapper<3> { 495 typedef thrust::zip_iterator< 496 thrust::tuple < thrust::counting_iterator 497 < uint32_t>, uint32_t*, RecordID*, RecordID*, RecordID*>> type; 498 }; 499 500 template<> 501 struct IndexZipIteratorMapper<4> { 502 typedef thrust::zip_iterator< 503 thrust::tuple < thrust::counting_iterator 504 < uint32_t>, uint32_t*, 505 RecordID*, RecordID*, RecordID*, RecordID*>> type; 506 }; 507 508 template<> 509 struct IndexZipIteratorMapper<5> { 510 typedef thrust::zip_iterator< 511 thrust::tuple < thrust::counting_iterator 512 < uint32_t>, uint32_t*, 513 RecordID*, RecordID*, RecordID*, RecordID*, RecordID*>> type; 514 }; 515 516 template<> 517 struct IndexZipIteratorMapper<6> { 518 typedef thrust::zip_iterator< 519 thrust::tuple < thrust::counting_iterator 520 < uint32_t>, uint32_t*, 521 RecordID*, RecordID*, RecordID*, 522 RecordID*, RecordID*, RecordID*>> type; 523 }; 524 525 template<> 526 struct IndexZipIteratorMapper<7> { 527 typedef thrust::zip_iterator< 528 thrust::tuple < thrust::counting_iterator 529 < uint32_t>, uint32_t*, 530 RecordID*, RecordID*, RecordID*, 531 RecordID*, RecordID*, RecordID*, RecordID*>> type; 532 }; 533 534 template<> 535 struct IndexZipIteratorMapper<8> { 536 typedef thrust::zip_iterator< 537 thrust::tuple < thrust::counting_iterator 538 < uint32_t>, uint32_t*, RecordID*, RecordID*, RecordID*, RecordID*, 539 RecordID*, RecordID*, RecordID*, RecordID*>> type; 540 }; 541 542 // IndexZipIteratorMaker is the factory to make the index zip iterator given 543 // a counting iterator, a main table index vector and 0 or more foreign table 544 // record id vector. It binds one record id vector at one time from the 545 // unboundForeignTableRecordIDVectors. If the NUnboundForeignTable is 0, it will 546 // just return the tuple 547 template<int NumTotalForeignTables, int NumUnboundForeignTables> 548 struct IndexZipIteratorMakerBase { 549 template<typename... RecordIDVector> 550 typename IndexZipIteratorMapper<NumTotalForeignTables>::type 551 make(uint32_t *index_vector, 552 RecordID **unboundForeignTableRecordIDVectors, 553 RecordIDVector... boundForeignRecordIDVectors) { 554 IndexZipIteratorMakerBase<NumTotalForeignTables, 555 NumUnboundForeignTables - 1> 556 nextMaker; 557 return nextMaker.make( 558 index_vector, 559 unboundForeignTableRecordIDVectors, 560 boundForeignRecordIDVectors..., 561 unboundForeignTableRecordIDVectors[ 562 NumTotalForeignTables 563 - NumUnboundForeignTables]); 564 } 565 }; 566 567 // Specialized IndexZipIteratorMakerBase with NUnboundForeignTable to be zero. 568 // Just bind everything together using thrust::make_tuple and return. 569 template<int NumTotalForeignTables> 570 struct IndexZipIteratorMakerBase<NumTotalForeignTables, 0> { 571 template<typename... RecordIDVector> 572 typename IndexZipIteratorMapper<NumTotalForeignTables>::type 573 make(uint32_t *indexVector, 574 RecordID **unboundForeignTableRecordIDVectors, 575 RecordIDVector... boundForeignRecordIDVectors) { 576 return thrust::make_zip_iterator(thrust::make_tuple( 577 thrust::counting_iterator<uint32_t>(0), indexVector, 578 boundForeignRecordIDVectors...)); 579 } 580 }; 581 582 template<int NumTotalForeignTables> 583 struct IndexZipIteratorMaker : public IndexZipIteratorMakerBase< 584 NumTotalForeignTables, 585 NumTotalForeignTables> { 586 }; 587 588 } // namespace ares 589 #endif // QUERY_BINDER_HPP_