github.com/rohankumardubey/aresdb@v0.0.2-0.20190517170215-e54e3ca06b9c/query/aql_processor.go (about) 1 // Copyright (c) 2017-2018 Uber Technologies, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package query 16 17 import ( 18 "math" 19 "unsafe" 20 21 "encoding/binary" 22 "github.com/uber/aresdb/memstore" 23 memCom "github.com/uber/aresdb/memstore/common" 24 "github.com/uber/aresdb/memutils" 25 queryCom "github.com/uber/aresdb/query/common" 26 "github.com/uber/aresdb/query/expr" 27 "github.com/uber/aresdb/utils" 28 "time" 29 ) 30 31 const ( 32 hllQueryRequiredMemoryInMB = 10 * 1024 33 ) 34 35 // batchTransferExecutor defines the type of the functor to transfer a live batch or a archive batch 36 // from host memory to device memory. hostVPs will be the columns to be released after transfer. startRow 37 // is used to slice the vector party. 38 type batchTransferExecutor func(stream unsafe.Pointer) (deviceColumns []deviceVectorPartySlice, 39 hostVPs []memCom.VectorParty, firstColumn, startRow, totalBytes, numTransfers int) 40 41 // customFilterExecutor is the functor to apply custom filters depends on the batch type. For archive batch, 42 // the custom filter will be the time filter and will only be applied to first or last batch. For live batch, 43 // the custom filters will be the cutoff time filter if cutoff is larger than 0, pre-filters and time filters. 44 type customFilterExecutor func(stream unsafe.Pointer) 45 46 // ProcessQuery processes the compiled query and executes it on GPU. 47 func (qc *AQLQueryContext) ProcessQuery(memStore memstore.MemStore) { 48 defer func() { 49 if r := recover(); r != nil { 50 // find out exactly what the error was and set err 51 switch x := r.(type) { 52 case string: 53 qc.Error = utils.StackError(nil, x) 54 case error: 55 qc.Error = utils.StackError(x, "Panic happens when processing query") 56 default: 57 qc.Error = utils.StackError(nil, "Panic happens when processing query %v", x) 58 } 59 utils.GetLogger().Error("Releasing device memory after panic") 60 qc.Release() 61 } 62 }() 63 64 qc.cudaStreams[0] = memutils.CreateCudaStream(qc.Device) 65 qc.cudaStreams[1] = memutils.CreateCudaStream(qc.Device) 66 qc.OOPK.currentBatch.device = qc.Device 67 qc.OOPK.LiveBatchStats = oopkQueryStats{ 68 Name2Stage: make(map[stageName]*oopkStageSummaryStats), 69 } 70 qc.OOPK.ArchiveBatchStats = oopkQueryStats{ 71 Name2Stage: make(map[stageName]*oopkStageSummaryStats), 72 } 73 74 previousBatchExecutor := NewDummyBatchExecutor() 75 76 start := utils.Now() 77 for joinTableID, join := range qc.Query.Joins { 78 qc.prepareForeignTable(memStore, joinTableID, join) 79 if qc.Error != nil { 80 return 81 } 82 } 83 qc.reportTiming(qc.cudaStreams[0], &start, prepareForeignTableTiming) 84 85 qc.prepareTimezoneTable(memStore) 86 if qc.Error != nil { 87 return 88 } 89 90 // prepare geo intersection 91 if qc.OOPK.geoIntersection != nil { 92 shapeExists := qc.prepareForGeoIntersect(memStore) 93 if qc.Error != nil { 94 return 95 } 96 if !shapeExists { 97 // if no shape exist and geo check for point in shape 98 // no need to continue processing batch 99 if qc.OOPK.geoIntersection.inOrOut { 100 return 101 } 102 // if no shape exist and geo check for point not in shape 103 // no need to do geo intersection 104 qc.OOPK.geoIntersection = nil 105 } 106 } 107 108 for _, shardID := range qc.TableScanners[0].Shards { 109 previousBatchExecutor = qc.processShard(memStore, shardID, previousBatchExecutor) 110 if qc.Error != nil { 111 return 112 } 113 if qc.OOPK.done { 114 break 115 } 116 } 117 118 // query execution for last batch. 119 qc.runBatchExecutor(previousBatchExecutor, true) 120 121 // this code snippet does the followings: 122 // 1. write stats to log. 123 // 2. allocate host buffer for result and copy the result from device to host. 124 // 3. clean up device status buffers if no panic. 125 if qc.Debug { 126 qc.OOPK.LiveBatchStats.writeToLog() 127 qc.OOPK.ArchiveBatchStats.writeToLog() 128 } 129 130 start = utils.Now() 131 if qc.Error == nil { 132 // Copy the result to host memory. 133 qc.OOPK.ResultSize = qc.OOPK.currentBatch.resultSize 134 if qc.OOPK.IsHLL() { 135 qc.HLLQueryResult, qc.Error = qc.PostprocessAsHLLData() 136 } else { 137 // copy dimensions 138 qc.OOPK.dimensionVectorH = memutils.HostAlloc(qc.OOPK.ResultSize * qc.OOPK.DimRowBytes) 139 asyncCopyDimensionVector(qc.OOPK.dimensionVectorH, qc.OOPK.currentBatch.dimensionVectorD[0].getPointer(), qc.OOPK.ResultSize, 0, 140 qc.OOPK.NumDimsPerDimWidth, qc.OOPK.ResultSize, qc.OOPK.currentBatch.resultCapacity, 141 memutils.AsyncCopyDeviceToHost, qc.cudaStreams[0], qc.Device) 142 if !qc.isNonAggregationQuery { 143 // copy measures 144 qc.OOPK.measureVectorH = memutils.HostAlloc(qc.OOPK.ResultSize * qc.OOPK.MeasureBytes) 145 memutils.AsyncCopyDeviceToHost( 146 qc.OOPK.measureVectorH, qc.OOPK.currentBatch.measureVectorD[0].getPointer(), 147 qc.OOPK.ResultSize*qc.OOPK.MeasureBytes, qc.cudaStreams[0], qc.Device) 148 } 149 memutils.WaitForCudaStream(qc.cudaStreams[0], qc.Device) 150 } 151 } 152 qc.reportTiming(qc.cudaStreams[0], &start, resultTransferTiming) 153 qc.cleanUpDeviceStatus() 154 qc.reportTiming(nil, &start, finalCleanupTiming) 155 } 156 157 func (qc *AQLQueryContext) processShard(memStore memstore.MemStore, shardID int, previousBatchExecutor BatchExecutor) BatchExecutor { 158 var liveRecordsProcessed, archiveRecordsProcessed, liveBatchProcessed, archiveBatchProcessed, liveBytesTransferred, archiveBytesTransferred int 159 shard, err := memStore.GetTableShard(qc.Query.Table, shardID) 160 if err != nil { 161 qc.Error = utils.StackError(err, "failed to get shard %d for table %s", 162 shardID, qc.Query.Table) 163 return previousBatchExecutor 164 } 165 defer shard.Users.Done() 166 167 var archiveStore *memstore.ArchiveStoreVersion 168 var cutoff uint32 169 if shard.Schema.Schema.IsFactTable { 170 archiveStore = shard.ArchiveStore.GetCurrentVersion() 171 defer archiveStore.Users.Done() 172 cutoff = archiveStore.ArchivingCutoff 173 } 174 175 // Process live batches. 176 if qc.toTime == nil || cutoff < uint32(qc.toTime.Time.Unix()) { 177 batchIDs, numRecordsInLastBatch := shard.LiveStore.GetBatchIDs() 178 for i, batchID := range batchIDs { 179 if qc.OOPK.done { 180 break 181 } 182 batch := shard.LiveStore.GetBatchForRead(batchID) 183 if batch == nil { 184 continue 185 } 186 187 // For now, dimension table does not persist min and max therefore 188 // we can only skip live batch for fact table. 189 // TODO: Persist min/max/numTrues when snapshotting. 190 if shard.Schema.Schema.IsFactTable && qc.shouldSkipLiveBatch(batch) { 191 batch.RUnlock() 192 qc.OOPK.LiveBatchStats.NumBatchSkipped++ 193 continue 194 } 195 196 liveBatchProcessed++ 197 size := batch.Capacity 198 if i == len(batchIDs)-1 { 199 size = numRecordsInLastBatch 200 } 201 liveRecordsProcessed += size 202 previousBatchExecutor = qc.processBatch(&batch.Batch, 203 batchID, 204 size, 205 qc.transferLiveBatch(batch, size), 206 qc.liveBatchCustomFilterExecutor(cutoff), previousBatchExecutor, true) 207 qc.cudaStreams[0], qc.cudaStreams[1] = qc.cudaStreams[1], qc.cudaStreams[0] 208 liveBytesTransferred += qc.OOPK.currentBatch.stats.bytesTransferred 209 } 210 } 211 212 // Process archive batches. 213 if archiveStore != nil && (qc.fromTime == nil || cutoff > uint32(qc.fromTime.Time.Unix())) { 214 scanner := qc.TableScanners[0] 215 for batchID := scanner.ArchiveBatchIDStart; batchID < scanner.ArchiveBatchIDEnd; batchID++ { 216 if qc.OOPK.done { 217 break 218 } 219 archiveBatch := archiveStore.RequestBatch(int32(batchID)) 220 if archiveBatch.Size == 0 { 221 qc.OOPK.ArchiveBatchStats.NumBatchSkipped++ 222 continue 223 } 224 isFirstOrLast := batchID == scanner.ArchiveBatchIDStart || batchID == scanner.ArchiveBatchIDEnd-1 225 previousBatchExecutor = qc.processBatch( 226 &archiveBatch.Batch, 227 int32(batchID), 228 archiveBatch.Size, 229 qc.transferArchiveBatch(archiveBatch, isFirstOrLast), 230 qc.archiveBatchCustomFilterExecutor(isFirstOrLast), 231 previousBatchExecutor, false) 232 archiveRecordsProcessed += archiveBatch.Size 233 archiveBatchProcessed++ 234 qc.cudaStreams[0], qc.cudaStreams[1] = qc.cudaStreams[1], qc.cudaStreams[0] 235 archiveBytesTransferred += qc.OOPK.currentBatch.stats.bytesTransferred 236 } 237 } 238 utils.GetReporter(qc.Query.Table, shardID).GetCounter(utils.QueryLiveRecordsProcessed).Inc(int64(liveRecordsProcessed)) 239 utils.GetReporter(qc.Query.Table, shardID).GetCounter(utils.QueryArchiveRecordsProcessed).Inc(int64(archiveRecordsProcessed)) 240 utils.GetReporter(qc.Query.Table, shardID).GetCounter(utils.QueryLiveBatchProcessed).Inc(int64(liveBatchProcessed)) 241 utils.GetReporter(qc.Query.Table, shardID).GetCounter(utils.QueryArchiveBatchProcessed).Inc(int64(archiveBatchProcessed)) 242 utils.GetReporter(qc.Query.Table, shardID).GetCounter(utils.QueryLiveBytesTransferred).Inc(int64(liveBytesTransferred)) 243 utils.GetReporter(qc.Query.Table, shardID).GetCounter(utils.QueryArchiveBytesTransferred).Inc(int64(archiveBytesTransferred)) 244 245 return previousBatchExecutor 246 } 247 248 // Release releases all device memory it allocated. It **should only called** when any errors happens while the query is 249 // processed. 250 func (qc *AQLQueryContext) Release() { 251 // release device memory for processing current batch. 252 qc.OOPK.currentBatch.cleanupBeforeAggregation() 253 qc.OOPK.currentBatch.swapResultBufferForNextBatch() 254 qc.cleanUpDeviceStatus() 255 qc.ReleaseHostResultsBuffers() 256 } 257 258 // CleanUpDevice cleans up the device status including 259 // 1. clean up the device buffer for storing results. 260 // 2. clean up the cuda streams 261 func (qc *AQLQueryContext) cleanUpDeviceStatus() { 262 // clean up foreign table memory after query 263 for _, foreignTable := range qc.OOPK.foreignTables { 264 qc.cleanUpForeignTable(foreignTable) 265 } 266 qc.OOPK.foreignTables = nil 267 268 // release geo pointers 269 if qc.OOPK.geoIntersection != nil { 270 deviceFreeAndSetNil(&qc.OOPK.geoIntersection.shapeLatLongs) 271 } 272 273 // Destroy streams 274 memutils.DestroyCudaStream(qc.cudaStreams[0], qc.Device) 275 memutils.DestroyCudaStream(qc.cudaStreams[1], qc.Device) 276 qc.cudaStreams = [2]unsafe.Pointer{nil, nil} 277 278 // Clean up the device result buffers. 279 qc.OOPK.currentBatch.cleanupDeviceResultBuffers() 280 281 // Clean up timezone lookup buffer. 282 deviceFreeAndSetNil(&qc.OOPK.currentBatch.timezoneLookupD) 283 } 284 285 // clean up foreign table 286 func (qc *AQLQueryContext) cleanUpForeignTable(table *foreignTable) { 287 if table != nil { 288 deviceFreeAndSetNil(&table.devicePrimaryKeyPtr) 289 for _, batch := range table.batches { 290 for _, column := range batch { 291 deviceFreeAndSetNil(&column.basePtr) 292 } 293 } 294 table.batches = nil 295 } 296 } 297 298 // getGeoShapeLatLongSlice format GeoShapeGo into slices of float32 for query purpose 299 // Lats and Longs are stored in the format as [a1,a2,...an,a1,MaxFloat32,b1,bz,...bn] 300 // refer to time_series_aggregate.h for GeoShape struct 301 func getGeoShapeLatLongSlice(shapesLats, shapesLongs []float32, gs memCom.GeoShapeGo) ([]float32, []float32, int) { 302 numPoints := 0 303 for i, polygon := range gs.Polygons { 304 if len(polygon) > 0 && i > 0 { 305 // write place holder at start of polygon 306 shapesLats = append(shapesLats, math.MaxFloat32) 307 shapesLongs = append(shapesLongs, math.MaxFloat32) 308 // FLT_MAX as placeholder for each polygon 309 numPoints++ 310 } 311 for _, point := range polygon { 312 shapesLats = append(shapesLats, point[0]) 313 shapesLongs = append(shapesLongs, point[1]) 314 numPoints++ 315 } 316 } 317 return shapesLats, shapesLongs, numPoints 318 } 319 320 func (qc *AQLQueryContext) prepareForGeoIntersect(memStore memstore.MemStore) (shapeExists bool) { 321 tableScanner := qc.TableScanners[qc.OOPK.geoIntersection.shapeTableID] 322 shapeColumnID := qc.OOPK.geoIntersection.shapeColumnID 323 tableName := tableScanner.Schema.Schema.Name 324 // geo table is not sharded 325 shard, err := memStore.GetTableShard(tableName, 0) 326 if err != nil { 327 qc.Error = utils.StackError(err, "Failed to get shard for table %s, shard: %d", tableName, 0) 328 return 329 } 330 defer shard.Users.Done() 331 332 numPointsPerShape := make([]int32, 0, len(qc.OOPK.geoIntersection.shapeUUIDs)) 333 qc.OOPK.geoIntersection.validShapeUUIDs = make([]string, 0, len(qc.OOPK.geoIntersection.shapeUUIDs)) 334 var shapesLats, shapesLongs []float32 335 var numPoints, totalNumPoints int 336 for _, uuid := range qc.OOPK.geoIntersection.shapeUUIDs { 337 recordID, found := shard.LiveStore.LookupKey([]string{uuid}) 338 if found { 339 batch := shard.LiveStore.GetBatchForRead(recordID.BatchID) 340 if batch != nil { 341 shapeValue := batch.GetDataValue(int(recordID.Index), shapeColumnID) 342 // compiler should have verified the geo column GeoShape type 343 shapesLats, shapesLongs, numPoints = getGeoShapeLatLongSlice(shapesLats, shapesLongs, *(shapeValue.GoVal.(*memCom.GeoShapeGo))) 344 if numPoints > 0 { 345 totalNumPoints += numPoints 346 numPointsPerShape = append(numPointsPerShape, int32(numPoints)) 347 qc.OOPK.geoIntersection.validShapeUUIDs = append(qc.OOPK.geoIntersection.validShapeUUIDs, uuid) 348 shapeExists = true 349 } 350 batch.RUnlock() 351 } 352 } 353 } 354 355 if !shapeExists { 356 return 357 } 358 359 numValidShapes := len(numPointsPerShape) 360 shapeIndexs := make([]uint8, totalNumPoints) 361 pointIndex := 0 362 for shapeIndex, numPoints := range numPointsPerShape { 363 for i := 0; i < int(numPoints); i++ { 364 shapeIndexs[pointIndex] = uint8(shapeIndex) 365 pointIndex++ 366 } 367 } 368 369 // allocate memory for lats, longs (float32) and numPoints (int32) device vectors 370 latsPtrD := deviceAllocate(totalNumPoints*4*2+totalNumPoints, qc.Device) 371 longsPtrD := latsPtrD.offset(totalNumPoints * 4) 372 shapeIndexsD := longsPtrD.offset(totalNumPoints * 4) 373 374 memutils.AsyncCopyHostToDevice(latsPtrD.getPointer(), unsafe.Pointer(&shapesLats[0]), totalNumPoints*4, qc.cudaStreams[0], qc.Device) 375 memutils.AsyncCopyHostToDevice(longsPtrD.getPointer(), unsafe.Pointer(&shapesLongs[0]), totalNumPoints*4, qc.cudaStreams[0], qc.Device) 376 memutils.AsyncCopyHostToDevice(shapeIndexsD.getPointer(), unsafe.Pointer(&shapeIndexs[0]), totalNumPoints, qc.cudaStreams[0], qc.Device) 377 378 qc.OOPK.geoIntersection.shapeLatLongs = latsPtrD 379 qc.OOPK.geoIntersection.numShapes = numValidShapes 380 qc.OOPK.geoIntersection.totalNumPoints = totalNumPoints 381 return 382 } 383 384 // prepare foreign table (allocate and transfer memory) before processing 385 func (qc *AQLQueryContext) prepareForeignTable(memStore memstore.MemStore, joinTableID int, join Join) { 386 ft := qc.OOPK.foreignTables[joinTableID] 387 if ft == nil { 388 return 389 } 390 391 // join only support dimension table for now 392 // and dimension table is not shared 393 shard, err := memStore.GetTableShard(join.Table, 0) 394 if err != nil { 395 qc.Error = utils.StackError(err, "Failed to get shard for table %s, shard: %d", join.Table, 0) 396 return 397 } 398 defer shard.Users.Done() 399 400 // only need live store for dimension table 401 batchIDs, numRecordsInLastBatch := shard.LiveStore.GetBatchIDs() 402 ft.numRecordsInLastBatch = numRecordsInLastBatch 403 deviceBatches := make([][]deviceVectorPartySlice, len(batchIDs)) 404 405 // transfer primary key 406 hostPrimaryKeyData := shard.LiveStore.PrimaryKey.LockForTransfer() 407 devicePrimaryKeyPtr := deviceAllocate(hostPrimaryKeyData.NumBytes, qc.Device) 408 memutils.AsyncCopyHostToDevice(devicePrimaryKeyPtr.getPointer(), hostPrimaryKeyData.Data, hostPrimaryKeyData.NumBytes, qc.cudaStreams[0], qc.Device) 409 memutils.WaitForCudaStream(qc.cudaStreams[0], qc.Device) 410 ft.hostPrimaryKeyData = hostPrimaryKeyData 411 ft.devicePrimaryKeyPtr = devicePrimaryKeyPtr 412 shard.LiveStore.PrimaryKey.UnlockAfterTransfer() 413 414 // allocate device memory 415 for i, batchID := range batchIDs { 416 batch := shard.LiveStore.GetBatchForRead(batchID) 417 if batch == nil { 418 continue 419 } 420 batchIndex := batchID - memstore.BaseBatchID 421 deviceBatches[batchIndex] = make([]deviceVectorPartySlice, len(qc.TableScanners[joinTableID+1].Columns)) 422 423 size := batch.Capacity 424 if i == len(batchIDs)-1 { 425 size = numRecordsInLastBatch 426 } 427 for i, columnID := range qc.TableScanners[joinTableID+1].Columns { 428 usage := qc.TableScanners[joinTableID+1].ColumnUsages[columnID] 429 if usage&(columnUsedByAllBatches|columnUsedByLiveBatches) != 0 { 430 sourceVP := batch.Columns[columnID] 431 if sourceVP == nil { 432 continue 433 } 434 435 hostVPSlice := sourceVP.(memstore.TransferableVectorParty).GetHostVectorPartySlice(0, size) 436 deviceBatches[batchIndex][i] = hostToDeviceColumn(hostVPSlice, qc.Device) 437 copyHostToDevice(hostVPSlice, deviceBatches[batchIndex][i], qc.cudaStreams[0], qc.Device) 438 } 439 } 440 memutils.WaitForCudaStream(qc.cudaStreams[0], qc.Device) 441 batch.RUnlock() 442 } 443 ft.batches = deviceBatches 444 } 445 446 // prepareTimezoneTable 447 func (qc *AQLQueryContext) prepareTimezoneTable(store memstore.MemStore) { 448 if qc.timezoneTable.tableColumn == "" { 449 return 450 } 451 452 // Timezone table 453 timezoneTableName := utils.GetConfig().Query.TimezoneTable.TableName 454 schema, err := store.GetSchema(timezoneTableName) 455 if err != nil { 456 qc.Error = err 457 return 458 } 459 if schema == nil { 460 qc.Error = utils.StackError(nil, "unknown timezone table %s", timezoneTableName) 461 return 462 } 463 464 timer := utils.GetRootReporter().GetTimer(utils.TimezoneLookupTableCreationTime) 465 start := utils.Now() 466 defer func() { 467 duration := utils.Now().Sub(start) 468 timer.Record(duration) 469 }() 470 471 schema.RLock() 472 defer schema.RUnlock() 473 474 if tzDict, found := schema.EnumDicts[qc.timezoneTable.tableColumn]; found { 475 lookUp := make([]int16, len(tzDict.ReverseDict)) 476 for i := range lookUp { 477 if loc, err := time.LoadLocation(tzDict.ReverseDict[i]); err == nil { 478 _, offset := time.Now().In(loc).Zone() 479 lookUp[i] = int16(offset) 480 } else { 481 qc.Error = utils.StackError(err, "error parsing timezone") 482 return 483 } 484 } 485 sizeInBytes := binary.Size(lookUp) 486 lookupPtr := deviceAllocate(sizeInBytes, qc.Device) 487 memutils.AsyncCopyHostToDevice(lookupPtr.getPointer(), unsafe.Pointer(&lookUp[0]), sizeInBytes, qc.cudaStreams[0], qc.Device) 488 qc.OOPK.currentBatch.timezoneLookupD = lookupPtr 489 qc.OOPK.currentBatch.timezoneLookupDSize = len(lookUp) 490 } else { 491 qc.Error = utils.StackError(nil, "unknown timezone column %s", qc.timezoneTable.tableColumn) 492 return 493 } 494 495 } 496 497 // transferLiveBatch returns a functor to transfer a live batch to device memory. The size parameter will be either the 498 // size of the batch or num records in last batch. hostColumns will always be empty since we should not release a vector 499 // party of a live batch. Start row will always be zero as well. 500 func (qc *AQLQueryContext) transferLiveBatch(batch *memstore.LiveBatch, size int) batchTransferExecutor { 501 return func(stream unsafe.Pointer) (deviceColumns []deviceVectorPartySlice, hostVPs []memCom.VectorParty, 502 firstColumn, startRow, totalBytes, numTransfers int) { 503 // Allocate column inputs. 504 firstColumn = -1 505 deviceColumns = make([]deviceVectorPartySlice, len(qc.TableScanners[0].Columns)) 506 for i, columnID := range qc.TableScanners[0].Columns { 507 usage := qc.TableScanners[0].ColumnUsages[columnID] 508 if usage&(columnUsedByAllBatches|columnUsedByLiveBatches) != 0 { 509 if firstColumn < 0 { 510 firstColumn = i 511 } 512 sourceVP := batch.Columns[columnID] 513 if sourceVP == nil { 514 continue 515 } 516 517 hostColumn := sourceVP.(memstore.TransferableVectorParty).GetHostVectorPartySlice(0, size) 518 deviceColumns[i] = hostToDeviceColumn(hostColumn, qc.Device) 519 b, t := copyHostToDevice(hostColumn, deviceColumns[i], stream, qc.Device) 520 totalBytes += b 521 numTransfers += t 522 } 523 } 524 return 525 } 526 } 527 528 // liveBatchTimeFilterExecutor returns a functor to apply custom time filters to live batch. 529 func (qc *AQLQueryContext) liveBatchCustomFilterExecutor(cutoff uint32) customFilterExecutor { 530 return func(stream unsafe.Pointer) { 531 // cutoff filter evaluation. 532 // only apply to fact table where cutoff > 0 533 if cutoff > 0 { 534 qc.OOPK.currentBatch.processExpression( 535 qc.createCutoffTimeFilter(cutoff), nil, 536 qc.TableScanners, qc.OOPK.foreignTables, stream, qc.Device, qc.OOPK.currentBatch.filterAction) 537 } 538 539 // time filter evaluation 540 for _, filter := range qc.OOPK.TimeFilters { 541 if filter != nil { 542 qc.OOPK.currentBatch.processExpression(filter, nil, 543 qc.TableScanners, qc.OOPK.foreignTables, stream, qc.Device, qc.OOPK.currentBatch.filterAction) 544 } 545 } 546 547 // prefilter evaluation 548 for _, filter := range qc.OOPK.Prefilters { 549 qc.OOPK.currentBatch.processExpression(filter, nil, 550 qc.TableScanners, qc.OOPK.foreignTables, stream, qc.Device, qc.OOPK.currentBatch.filterAction) 551 } 552 } 553 } 554 555 // transferArchiveBatch returns the functor to transfer an archive batch to device memory. We will need to release 556 // hostColumns after transfer completes. 557 func (qc *AQLQueryContext) transferArchiveBatch(batch *memstore.ArchiveBatch, 558 isFirstOrLast bool) batchTransferExecutor { 559 return func(stream unsafe.Pointer) (deviceSlices []deviceVectorPartySlice, hostVPs []memCom.VectorParty, 560 firstColumn, startRow, totalBytes, numTransfers int) { 561 matchedColumnUsages := columnUsedByAllBatches 562 if isFirstOrLast { 563 matchedColumnUsages |= columnUsedByFirstArchiveBatch | columnUsedByLastArchiveBatch 564 } 565 566 // Request columns, prefilter-slicing, allocate column inputs. 567 firstColumn = -1 568 hostVPs = make([]memCom.VectorParty, len(qc.TableScanners[0].Columns)) 569 hostSlices := make([]memCom.HostVectorPartySlice, len(qc.TableScanners[0].Columns)) 570 deviceSlices = make([]deviceVectorPartySlice, len(qc.TableScanners[0].Columns)) 571 endRow := batch.Size 572 prefilterIndex := 0 573 // Must iterate in reverse order to apply prefilter slicing properly. 574 for i := len(qc.TableScanners[0].Columns) - 1; i >= 0; i-- { 575 columnID := qc.TableScanners[0].Columns[i] 576 usage := qc.TableScanners[0].ColumnUsages[columnID] 577 578 if usage&matchedColumnUsages != 0 || usage&columnUsedByPrefilter != 0 { 579 // Request/pin column from disk and wait. 580 vp := batch.RequestVectorParty(columnID) 581 vp.WaitForDiskLoad() 582 583 // prefilter slicing 584 startRow, endRow, hostSlices[i] = qc.prefilterSlice(vp, prefilterIndex, startRow, endRow) 585 prefilterIndex++ 586 587 if usage&matchedColumnUsages != 0 { 588 hostVPs[i] = vp 589 firstColumn = i 590 deviceSlices[i] = hostToDeviceColumn(hostSlices[i], qc.Device) 591 } else { 592 vp.Release() 593 } 594 } 595 } 596 597 for i, dstVPSlice := range deviceSlices { 598 columnID := qc.TableScanners[0].Columns[i] 599 usage := qc.TableScanners[0].ColumnUsages[columnID] 600 if usage&matchedColumnUsages != 0 { 601 srcVPSlice := hostSlices[i] 602 b, t := copyHostToDevice(srcVPSlice, dstVPSlice, stream, qc.Device) 603 totalBytes += b 604 numTransfers += t 605 } 606 } 607 return 608 } 609 } 610 611 // archiveBatchCustomFilterExecutor returns a functor to apply custom filter to first or last archive batch. 612 func (qc *AQLQueryContext) archiveBatchCustomFilterExecutor(isFirstOrLast bool) customFilterExecutor { 613 return func(stream unsafe.Pointer) { 614 if isFirstOrLast { 615 for _, filter := range qc.OOPK.TimeFilters { 616 if filter != nil { 617 qc.OOPK.currentBatch.processExpression(filter, nil, 618 qc.TableScanners, qc.OOPK.foreignTables, stream, qc.Device, qc.OOPK.currentBatch.filterAction) 619 } 620 } 621 } 622 } 623 } 624 625 // helper function for copy dimension vector. Returns the total size of dimension vector. 626 func asyncCopyDimensionVector(toDimVector, fromDimVector unsafe.Pointer, length, offset int, numDimsPerDimWidth queryCom.DimCountsPerDimWidth, 627 toVectorCapacity, fromVectorCapacity int, copyFunc memutils.AsyncMemCopyFunc, 628 stream unsafe.Pointer, device int) { 629 630 ptrFrom, ptrTo := fromDimVector, toDimVector 631 numNullVectors := 0 632 for _, numDims := range numDimsPerDimWidth { 633 numNullVectors += int(numDims) 634 } 635 636 dimBytes := 1 << uint(len(numDimsPerDimWidth)-1) 637 bytesToCopy := length * dimBytes 638 for _, numDim := range numDimsPerDimWidth { 639 for i := 0; i < int(numDim); i++ { 640 ptrTemp := utils.MemAccess(ptrTo, dimBytes*offset) 641 copyFunc(ptrTemp, ptrFrom, bytesToCopy, stream, device) 642 ptrTo = utils.MemAccess(ptrTo, dimBytes*toVectorCapacity) 643 ptrFrom = utils.MemAccess(ptrFrom, dimBytes*fromVectorCapacity) 644 } 645 dimBytes >>= 1 646 bytesToCopy = length * dimBytes 647 } 648 649 // copy null bytes 650 for i := 0; i < numNullVectors; i++ { 651 ptrTemp := utils.MemAccess(ptrTo, offset) 652 copyFunc(ptrTemp, ptrFrom, length, stream, device) 653 ptrTo = utils.MemAccess(ptrTo, toVectorCapacity) 654 ptrFrom = utils.MemAccess(ptrFrom, fromVectorCapacity) 655 } 656 } 657 658 // dimValueVectorSize returns the size of final dim value vector on host side. 659 func dimValResVectorSize(resultSize int, numDimsPerDimWidth queryCom.DimCountsPerDimWidth) int { 660 totalDims := 0 661 for _, numDims := range numDimsPerDimWidth { 662 totalDims += int(numDims) 663 } 664 665 dimBytes := 1 << uint(len(numDimsPerDimWidth)-1) 666 var totalBytes int 667 for _, numDims := range numDimsPerDimWidth { 668 totalBytes += dimBytes * resultSize * int(numDims) 669 dimBytes >>= 1 670 } 671 672 totalBytes += totalDims * resultSize 673 return totalBytes 674 } 675 676 // cleanupDeviceResultBuffers cleans up result buffers and resets result fields. 677 func (bc *oopkBatchContext) cleanupDeviceResultBuffers() { 678 deviceFreeAndSetNil(&bc.dimensionVectorD[0]) 679 deviceFreeAndSetNil(&bc.dimensionVectorD[1]) 680 681 deviceFreeAndSetNil(&bc.dimIndexVectorD[0]) 682 deviceFreeAndSetNil(&bc.dimIndexVectorD[1]) 683 684 deviceFreeAndSetNil(&bc.hashVectorD[0]) 685 deviceFreeAndSetNil(&bc.hashVectorD[1]) 686 687 deviceFreeAndSetNil(&bc.measureVectorD[0]) 688 deviceFreeAndSetNil(&bc.measureVectorD[1]) 689 690 bc.resultSize = 0 691 bc.resultCapacity = 0 692 } 693 694 // clean up memory not used in final aggregation (sort, reduce, hll) 695 // before aggregation happen 696 func (bc *oopkBatchContext) cleanupBeforeAggregation() { 697 for _, column := range bc.columns { 698 deviceFreeAndSetNil(&column.basePtr) 699 } 700 bc.columns = nil 701 702 deviceFreeAndSetNil(&bc.indexVectorD) 703 deviceFreeAndSetNil(&bc.predicateVectorD) 704 deviceFreeAndSetNil(&bc.geoPredicateVectorD) 705 706 for _, recordIDsVector := range bc.foreignTableRecordIDsD { 707 deviceFreeAndSetNil(&recordIDsVector) 708 } 709 bc.foreignTableRecordIDsD = nil 710 711 for _, stackFrame := range bc.exprStackD { 712 deviceFreeAndSetNil(&stackFrame[0]) 713 } 714 bc.exprStackD = nil 715 } 716 717 // swapResultBufferForNextBatch swaps the two 718 // sets of dim/measure/hash vectors to get ready for the next batch. 719 func (bc *oopkBatchContext) swapResultBufferForNextBatch() { 720 bc.size = 0 721 bc.dimensionVectorD[0], bc.dimensionVectorD[1] = bc.dimensionVectorD[1], bc.dimensionVectorD[0] 722 bc.measureVectorD[0], bc.measureVectorD[1] = bc.measureVectorD[1], bc.measureVectorD[0] 723 bc.hashVectorD[0], bc.hashVectorD[1] = bc.hashVectorD[1], bc.hashVectorD[0] 724 } 725 726 // prepareForFiltering prepares the input and the index vectors for filtering. 727 func (bc *oopkBatchContext) prepareForFiltering( 728 columns []deviceVectorPartySlice, firstColumn int, startRow int, stream unsafe.Pointer) { 729 bc.columns = columns 730 bc.startRow = startRow 731 732 if firstColumn >= 0 { 733 bc.size = columns[firstColumn].length 734 // Allocate twice of the size to save number of allocations of temporary index vector. 735 bc.indexVectorD = deviceAllocate(bc.size*4, bc.device) 736 bc.predicateVectorD = deviceAllocate(bc.size, bc.device) 737 bc.baseCountD = columns[firstColumn].counts.offset(columns[firstColumn].countStartIndex * 4) 738 } 739 bc.stats.batchSize = bc.size 740 } 741 742 // prepareForDimAndMeasureEval ensures that dim/measure vectors have enough 743 // capacity for bc.resultSize+bc.size. 744 func (bc *oopkBatchContext) prepareForDimAndMeasureEval( 745 dimRowBytes int, measureBytes int, numDimsPerDimWidth queryCom.DimCountsPerDimWidth, isHLL bool, stream unsafe.Pointer) { 746 if bc.resultSize+bc.size > bc.resultCapacity { 747 oldCapacity := bc.resultCapacity 748 749 bc.resultCapacity = bc.resultSize + bc.size 750 // Extra budget for future proofing. 751 bc.resultCapacity += bc.resultCapacity / 8 752 753 bc.dimensionVectorD = bc.reallocateResultBuffers(bc.dimensionVectorD, dimRowBytes, stream, func(to, from unsafe.Pointer) { 754 asyncCopyDimensionVector(to, from, bc.resultSize, 0, 755 numDimsPerDimWidth, bc.resultCapacity, oldCapacity, 756 memutils.AsyncCopyDeviceToDevice, stream, bc.device) 757 }) 758 759 // uint32_t for index value 760 bc.dimIndexVectorD = bc.reallocateResultBuffers(bc.dimIndexVectorD, 4, stream, nil) 761 // uint64_t for hash value 762 // Note: only when aggregate function is hll, we need to reuse vector[0] 763 if isHLL { 764 bc.hashVectorD = bc.reallocateResultBuffers(bc.hashVectorD, 8, stream, func(to, from unsafe.Pointer) { 765 memutils.AsyncCopyDeviceToDevice(to, from, bc.resultSize*8, stream, bc.device) 766 }) 767 } else { 768 bc.hashVectorD = bc.reallocateResultBuffers(bc.hashVectorD, 8, stream, nil) 769 } 770 771 bc.measureVectorD = bc.reallocateResultBuffers(bc.measureVectorD, measureBytes, stream, func(to, from unsafe.Pointer) { 772 memutils.AsyncCopyDeviceToDevice(to, from, bc.resultSize*measureBytes, stream, bc.device) 773 }) 774 } 775 } 776 777 // reallocateResultBuffers reallocates the result buffer pair to size 778 // resultCapacity*unitBytes and copies resultSize*unitBytes from input[0] to output[0]. 779 func (bc *oopkBatchContext) reallocateResultBuffers( 780 input [2]devicePointer, unitBytes int, stream unsafe.Pointer, copyFunc func(to, from unsafe.Pointer)) (output [2]devicePointer) { 781 782 output = [2]devicePointer{ 783 deviceAllocate(bc.resultCapacity*unitBytes, bc.device), 784 deviceAllocate(bc.resultCapacity*unitBytes, bc.device), 785 } 786 787 if copyFunc != nil { 788 copyFunc(output[0].getPointer(), input[0].getPointer()) 789 } 790 791 deviceFreeAndSetNil(&input[0]) 792 deviceFreeAndSetNil(&input[1]) 793 return 794 } 795 796 // doProfile checks the corresponding profileName against query parameter 797 // and do cuda profiling for this action if name matches. 798 func (qc *AQLQueryContext) doProfile(action func(), profileName string, stream unsafe.Pointer) { 799 if qc.Profiling == profileName { 800 // explicit waiting for cuda stream to avoid profiling previous actions. 801 memutils.WaitForCudaStream(stream, qc.Device) 802 utils.GetQueryLogger().Infof("Starting cuda profiler for %s", profileName) 803 memutils.CudaProfilerStart() 804 defer func() { 805 // explicit waiting for cuda stream to wait for completion of current action. 806 memutils.WaitForCudaStream(stream, qc.Device) 807 utils.GetQueryLogger().Infof("Stopping cuda profiler for %s", profileName) 808 memutils.CudaProfilerStop() 809 }() 810 } 811 action() 812 } 813 814 // processBatch allocates device memory and starts async input data 815 // transferring to device memory. It then invokes previousBatchExecutor 816 // asynchronously to process the previous batch. When both async operations 817 // finish, it prepares for the current batch execution and returns it as 818 // a function closure to be invoked later. customFilterExecutor is the executor 819 // to apply custom filters for live batch and archive batch. 820 func (qc *AQLQueryContext) processBatch( 821 batch *memstore.Batch, batchID int32, batchSize int, transferFunc batchTransferExecutor, 822 customFilterFunc customFilterExecutor, previousBatchExecutor BatchExecutor, needToUnlockBatch bool) BatchExecutor { 823 defer func() { 824 if needToUnlockBatch { 825 batch.RUnlock() 826 } 827 }() 828 829 if qc.Debug { 830 // Finish executing previous batch first to avoid timeline overlapping 831 qc.runBatchExecutor(previousBatchExecutor, false) 832 previousBatchExecutor = NewDummyBatchExecutor() 833 } 834 835 // reset stats. 836 qc.OOPK.currentBatch.stats = oopkBatchStats{ 837 batchID: batchID, 838 timings: make(map[stageName]float64), 839 } 840 start := utils.Now() 841 842 // Async transfer. 843 stream := qc.cudaStreams[0] 844 deviceSlices, hostVPs, firstColumn, startRow, totalBytes, numTransfers := transferFunc(stream) 845 qc.OOPK.currentBatch.stats.bytesTransferred += totalBytes 846 qc.OOPK.currentBatch.stats.numTransferCalls += numTransfers 847 848 qc.reportTimingForCurrentBatch(stream, &start, transferTiming) 849 850 // Async execute the previous batch. 851 executionDone := make(chan struct{ error }, 1) 852 go func() { 853 defer func() { 854 if r := recover(); r != nil { 855 var err error 856 // find out exactly what the error was and set err 857 switch x := r.(type) { 858 case string: 859 err = utils.StackError(nil, x) 860 case error: 861 err = utils.StackError(x, "Panic happens when executing query") 862 default: 863 err = utils.StackError(nil, "Panic happens when executing query %v", x) 864 } 865 executionDone <- struct{ error }{err} 866 } 867 }() 868 qc.runBatchExecutor(previousBatchExecutor, false) 869 executionDone <- struct{ error }{} 870 }() 871 872 // Wait for data transfer of the current batch. 873 memutils.WaitForCudaStream(stream, qc.Device) 874 875 for _, vp := range hostVPs { 876 if vp != nil { 877 // only archive vector party will be returned after transfer function 878 vp.(memCom.ArchiveVectorParty).Release() 879 } 880 } 881 882 if needToUnlockBatch { 883 batch.RUnlock() 884 needToUnlockBatch = false 885 } 886 887 // Wait for execution of the previous batch. 888 res := <-executionDone 889 if res.error != nil { 890 // column data transfer for current batch is done 891 // need release current batch's column data before panic 892 for _, column := range deviceSlices { 893 deviceFreeAndSetNil(&column.basePtr) 894 } 895 panic(res.error) 896 } 897 898 if qc.OOPK.done { 899 // if the query is already satisfied in the middle, we can skip next batch and return 900 for _, column := range deviceSlices { 901 deviceFreeAndSetNil(&column.basePtr) 902 } 903 return NewDummyBatchExecutor() 904 } 905 906 // no prefilter slicing in livebatch, startRow is always 0 907 qc.OOPK.currentBatch.size = batchSize 908 qc.OOPK.currentBatch.prepareForFiltering(deviceSlices, firstColumn, startRow, stream) 909 910 qc.reportTimingForCurrentBatch(stream, &start, prepareForFilteringTiming) 911 912 return NewBatchExecutor(qc, batchID, customFilterFunc, stream) 913 } 914 915 // prefilterSlice does the following: 916 // 1. binary search for prefilter values following the matched sort column order 917 // 2. record matched index range on these matched sort columns 918 // 3. binary search on unmatched compressed columns for the row number range 919 // 4. index slice on uncompressed columns for the row number range 920 // 5. align/pad all slices to be pushed 921 func (qc *AQLQueryContext) prefilterSlice(vp memCom.ArchiveVectorParty, prefilterIndex, startRow, endRow int) (int, int, memCom.HostVectorPartySlice) { 922 startIndex, endIndex := 0, vp.GetLength() 923 924 unmatchedColumn := false 925 scanner := qc.TableScanners[0] 926 if prefilterIndex < len(scanner.EqualityPrefilterValues) { 927 // matched equality filter 928 filterValue := scanner.EqualityPrefilterValues[prefilterIndex] 929 startRow, endRow, startIndex, endIndex = vp.SliceByValue(startRow, endRow, unsafe.Pointer(&filterValue)) 930 } else if prefilterIndex == len(scanner.EqualityPrefilterValues) { 931 // matched range filter 932 // lower bound 933 filterValue := scanner.RangePrefilterValues[0] 934 boundaryType := scanner.RangePrefilterBoundaries[0] 935 936 if boundaryType != noBoundary { 937 lowerStartRow, lowerEndRow, lowerStartIndex, lowerEndIndex := vp.SliceByValue(startRow, endRow, unsafe.Pointer(&filterValue)) 938 if boundaryType == inclusiveBoundary { 939 startRow, startIndex = lowerStartRow, lowerStartIndex 940 } else { 941 startRow, startIndex = lowerEndRow, lowerEndIndex 942 } 943 } else { 944 // treat as unmatchedColumn when there is one range filter missing 945 unmatchedColumn = true 946 } 947 948 // SliceByValue of upperBound 949 filterValue = scanner.RangePrefilterValues[1] 950 boundaryType = scanner.RangePrefilterBoundaries[1] 951 if boundaryType != noBoundary { 952 upperStartRow, upperEndRow, upperStartIndex, upperEndIndex := vp.SliceByValue(startRow, endRow, unsafe.Pointer(&filterValue)) 953 if boundaryType == inclusiveBoundary { 954 endRow, endIndex = upperEndRow, upperEndIndex 955 } else { 956 endRow, endIndex = upperStartRow, upperStartIndex 957 } 958 } else { 959 // treat as unmatchedColumn when there is one range filter missing 960 unmatchedColumn = true 961 } 962 } else { 963 unmatchedColumn = true 964 } 965 966 if unmatchedColumn { 967 // unmatched columns, simply slice based on row number range 968 startIndex, endIndex = vp.SliceIndex(startRow, endRow) 969 } 970 971 return startRow, endRow, vp.(memstore.TransferableVectorParty).GetHostVectorPartySlice(startIndex, endIndex-startIndex) 972 } 973 974 // calculateMemoryRequirement estimate memory requirement for batch data. 975 func (qc *AQLQueryContext) calculateMemoryRequirement(memStore memstore.MemStore) int { 976 // keep track of max requirement for batch 977 maxBytesRequired := 0 978 979 //TODO(jians): hard code hll query memory requirement here for now, 980 //we can track memory usage 981 //based on table, dimensions, duration to do estimation 982 if qc.OOPK.IsHLL() { 983 return hllQueryRequiredMemoryInMB 984 } 985 986 for _, shardID := range qc.TableScanners[0].Shards { 987 shard, err := memStore.GetTableShard(qc.Query.Table, shardID) 988 if err != nil { 989 qc.Error = utils.StackError(err, "failed to get shard %d for table %s", 990 shardID, qc.Query.Table) 991 return -1 992 } 993 994 var archiveStore *memstore.ArchiveStoreVersion 995 var cutoff uint32 996 if shard.Schema.Schema.IsFactTable { 997 archiveStore = shard.ArchiveStore.GetCurrentVersion() 998 cutoff = archiveStore.ArchivingCutoff 999 } 1000 1001 // estimate live batch memory usage 1002 if qc.toTime == nil || cutoff < uint32(qc.toTime.Time.Unix()) { 1003 batchIDs, _ := shard.LiveStore.GetBatchIDs() 1004 1005 // find first non null batch and estimate. 1006 for _, batchID := range batchIDs { 1007 liveBatch := shard.LiveStore.GetBatchForRead(batchID) 1008 if liveBatch != nil { 1009 batchBytes := qc.estimateLiveBatchMemoryUsage(liveBatch) 1010 liveBatch.RUnlock() 1011 1012 if batchBytes > maxBytesRequired { 1013 maxBytesRequired = batchBytes 1014 } 1015 break 1016 } 1017 } 1018 } 1019 1020 // estimate archive batch memory usage 1021 if archiveStore != nil { 1022 if qc.fromTime == nil || cutoff > uint32(qc.fromTime.Time.Unix()) { 1023 scanner := qc.TableScanners[0] 1024 for batchID := scanner.ArchiveBatchIDStart; batchID < scanner.ArchiveBatchIDEnd; batchID++ { 1025 archiveBatch := archiveStore.RequestBatch(int32(batchID)) 1026 if archiveBatch == nil || archiveBatch.Size == 0 { 1027 continue 1028 } 1029 isFirstOrLast := batchID == scanner.ArchiveBatchIDStart || batchID == scanner.ArchiveBatchIDEnd-1 1030 batchBytes := qc.estimateArchiveBatchMemoryUsage(archiveBatch, isFirstOrLast) 1031 if batchBytes > maxBytesRequired { 1032 maxBytesRequired = batchBytes 1033 } 1034 } 1035 } 1036 archiveStore.Users.Done() 1037 } 1038 shard.Users.Done() 1039 } 1040 1041 maxBytesRequired += qc.calculateForeignTableMemUsage(memStore) 1042 return maxBytesRequired 1043 } 1044 1045 // estimateLiveBatchMemoryUsage estimate the GPU memory usage for live batches 1046 func (qc *AQLQueryContext) estimateLiveBatchMemoryUsage(batch *memstore.LiveBatch) int { 1047 columnMemUsage := 0 1048 for _, columnID := range qc.TableScanners[0].Columns { 1049 sourceVP := batch.Columns[columnID] 1050 if sourceVP == nil { 1051 continue 1052 } 1053 columnMemUsage += int(sourceVP.GetBytes()) 1054 } 1055 1056 totalBytes := qc.estimateMemUsageForBatch(batch.Capacity, columnMemUsage) 1057 utils.GetQueryLogger().Debugf("Live batch %+v needs memory: %d", batch, totalBytes) 1058 return totalBytes 1059 } 1060 1061 // estimateArchiveBatchMemoryUsage estimate the GPU memory usage for archive batch 1062 func (qc *AQLQueryContext) estimateArchiveBatchMemoryUsage(batch *memstore.ArchiveBatch, isFirstOrLast bool) int { 1063 if batch == nil { 1064 return 0 1065 } 1066 1067 columnMemUsage := 0 1068 var firstColumnSize int 1069 startRow, endRow := 0, batch.Size 1070 var hostSlice memCom.HostVectorPartySlice 1071 1072 matchedColumnUsages := columnUsedByAllBatches 1073 if isFirstOrLast { 1074 matchedColumnUsages |= columnUsedByFirstArchiveBatch | columnUsedByLastArchiveBatch 1075 } 1076 1077 prefilterIndex := 0 1078 for i := len(qc.TableScanners[0].Columns) - 1; i >= 0; i-- { 1079 columnID := qc.TableScanners[0].Columns[i] 1080 usage := qc.TableScanners[0].ColumnUsages[columnID] 1081 // TODO(cdavid): only read metadata when estimate query memory requirement. 1082 sourceVP := batch.RequestVectorParty(columnID) 1083 sourceVP.WaitForDiskLoad() 1084 1085 if usage&matchedColumnUsages != 0 || usage&columnUsedByPrefilter != 0 { 1086 startRow, endRow, hostSlice = qc.prefilterSlice(sourceVP, prefilterIndex, startRow, endRow) 1087 prefilterIndex++ 1088 if usage&matchedColumnUsages != 0 { 1089 columnMemUsage += hostSlice.ValueBytes + hostSlice.NullBytes + hostSlice.CountBytes 1090 firstColumnSize = hostSlice.Length 1091 } 1092 } 1093 sourceVP.Release() 1094 } 1095 1096 totalBytes := qc.estimateMemUsageForBatch(firstColumnSize, columnMemUsage) 1097 utils.GetQueryLogger().Debugf("Archive batch %d needs memory: %d", batch.BatchID, totalBytes) 1098 return totalBytes 1099 } 1100 1101 // estimateMemUsageForBatch calculates memory usage including: 1102 // * Index vector 1103 // * Predicate vector 1104 // * Dimension 1105 // * Measurement 1106 // * Sort (hash/index) 1107 // * Reduce 1108 func (qc *AQLQueryContext) estimateMemUsageForBatch(firstColumnSize int, columnMemUsage int) (memUsage int) { 1109 // 1. columnMemUsage 1110 memUsageBeforeAgg := columnMemUsage 1111 1112 // 2. index vector memory usage (4 bytes each) 1113 memUsageBeforeAgg += firstColumnSize * 4 1114 1115 // 3. predicate memory usage (1 byte each) 1116 memUsageBeforeAgg += firstColumnSize 1117 1118 // 4. record id vector for foreign table (8 bytes each recordID) 1119 memUsageBeforeAgg += firstColumnSize * 8 * len(qc.OOPK.foreignTables) 1120 1121 // 5. expression eval memory (max scratch space) 1122 memUsageBeforeAgg += qc.estimateExpressionEvaluationMemUsage(firstColumnSize) 1123 1124 // 6. geoPredicateVector 1125 if qc.OOPK.geoIntersection != nil { 1126 memUsageBeforeAgg += firstColumnSize * 4 * 2 1127 } 1128 1129 // 7. max(memUsageBeforeAgg, sortReduceMemoryUsage) 1130 memUsage = int(math.Max(float64(memUsageBeforeAgg), float64(estimateSortReduceMemUsage(firstColumnSize)))) 1131 1132 // 8. Dimension vector memory usage (input + output) 1133 memUsage += firstColumnSize * qc.OOPK.DimRowBytes * 2 1134 1135 // 9. Measure vector memory usage (input + output) 1136 memUsage += firstColumnSize * qc.OOPK.MeasureBytes * 2 1137 1138 return 1139 } 1140 1141 // memory usage duration expression (filter, dimension, measure) evaluation 1142 func (qc *AQLQueryContext) estimateExpressionEvaluationMemUsage(inputSize int) (memUsage int) { 1143 // filter expression evaluation 1144 for _, filter := range qc.OOPK.MainTableCommonFilters { 1145 _, maxExpMemUsage := estimateScratchSpaceMemUsage(filter, inputSize, true) 1146 utils.GetQueryLogger().Debugf("Filter %+v: maxExpMemUsage=%d", filter, maxExpMemUsage) 1147 memUsage = int(math.Max(float64(memUsage), float64(maxExpMemUsage))) 1148 } 1149 1150 for _, filter := range qc.OOPK.ForeignTableCommonFilters { 1151 _, maxExpMemUsage := estimateScratchSpaceMemUsage(filter, inputSize, true) 1152 utils.GetQueryLogger().Debugf("Filter %+v: maxExpMemUsage=%d", filter, maxExpMemUsage) 1153 memUsage = int(math.Max(float64(memUsage), float64(maxExpMemUsage))) 1154 } 1155 1156 // dimension expression evaluation 1157 for _, dimension := range qc.OOPK.Dimensions { 1158 _, maxExpMemUsage := estimateScratchSpaceMemUsage(dimension, inputSize, true) 1159 utils.GetQueryLogger().Debugf("Dimension %+v: maxExpMemUsage=%d", dimension, maxExpMemUsage) 1160 memUsage = int(math.Max(float64(memUsage), float64(maxExpMemUsage))) 1161 } 1162 1163 // measure expression evaluation 1164 _, maxExpMemUsage := estimateScratchSpaceMemUsage(qc.OOPK.Measure, inputSize, true) 1165 utils.GetQueryLogger().Debugf("Measure %+v: maxExpMemUsage=%d", qc.OOPK.Measure, maxExpMemUsage) 1166 memUsage = int(math.Max(float64(memUsage), float64(maxExpMemUsage))) 1167 1168 return memUsage 1169 } 1170 1171 // Note: we only calculate Sort memory usage 1172 // since sort memory usage is larger than reduce 1173 // and we only care about the maximum 1174 func estimateSortReduceMemUsage(inputSize int) (memUsage int) { 1175 // dimension index vector 1176 // 4 byte for uint32 1177 // 2 vectors for input and output 1178 memUsage += inputSize * 4 * 2 1179 // hash vector 1180 // 8 byte for uint64 hash value 1181 // 2 vectors for input and output 1182 memUsage += inputSize * 8 * 2 1183 // we sort dim index values as value, and hash value as key 1184 memUsage += inputSize * (8 + 4) 1185 return 1186 } 1187 1188 // estimateScratchSpaceMemUsage calculates memory usage for an expression 1189 func estimateScratchSpaceMemUsage(exp expr.Expr, firstColumnSize int, isRoot bool) (int, int) { 1190 var currentMemUsage int 1191 var maxMemUsage int 1192 1193 switch e := exp.(type) { 1194 case *expr.ParenExpr: 1195 return estimateScratchSpaceMemUsage(e.Expr, firstColumnSize, isRoot) 1196 case *expr.UnaryExpr: 1197 childCurrentMemUsage, childMaxMemUsage := estimateScratchSpaceMemUsage(e.Expr, firstColumnSize, false) 1198 if !isRoot { 1199 currentMemUsage = firstColumnSize * 5 1200 } 1201 maxMemUsage = int(math.Max(float64(childCurrentMemUsage+currentMemUsage), float64(childMaxMemUsage))) 1202 return currentMemUsage, maxMemUsage 1203 case *expr.BinaryExpr: 1204 lhsCurrentMemUsage, lhsMaxMemUsage := estimateScratchSpaceMemUsage(e.LHS, firstColumnSize, false) 1205 rhsCurrentMemUsage, rhsMaxMemUsage := estimateScratchSpaceMemUsage(e.RHS, firstColumnSize, false) 1206 1207 if !isRoot { 1208 currentMemUsage = firstColumnSize * 5 1209 } 1210 1211 childrenMaxMemUsage := math.Max(float64(lhsMaxMemUsage), float64(rhsMaxMemUsage)) 1212 maxMemUsage = int(math.Max(float64(currentMemUsage+lhsCurrentMemUsage+rhsCurrentMemUsage), float64(childrenMaxMemUsage))) 1213 1214 return currentMemUsage, maxMemUsage 1215 default: 1216 return 0, 0 1217 } 1218 } 1219 1220 // calculateForeignTableMemUsage returns how much device memory is needed for foreign table 1221 func (qc *AQLQueryContext) calculateForeignTableMemUsage(memStore memstore.MemStore) int { 1222 var memUsage int 1223 1224 for joinTableID, join := range qc.Query.Joins { 1225 // join only support dimension table for now 1226 // and dimension table is not shared 1227 shard, err := memStore.GetTableShard(join.Table, 0) 1228 if err != nil { 1229 qc.Error = utils.StackError(err, "Failed to get shard for table %s, shard: %d", join.Table, 0) 1230 return 0 1231 } 1232 1233 // only need live store for dimension table 1234 batchIDs, _ := shard.LiveStore.GetBatchIDs() 1235 1236 // primary key 1237 memUsage += int(shard.LiveStore.PrimaryKey.AllocatedBytes()) 1238 1239 // VPs 1240 for _, batchID := range batchIDs { 1241 batch := shard.LiveStore.GetBatchForRead(batchID) 1242 if batch == nil { 1243 continue 1244 } 1245 1246 for _, columnID := range qc.TableScanners[joinTableID+1].Columns { 1247 usage := qc.TableScanners[joinTableID+1].ColumnUsages[columnID] 1248 if usage&(columnUsedByAllBatches|columnUsedByLiveBatches) != 0 { 1249 sourceVP := batch.Columns[columnID] 1250 if sourceVP == nil { 1251 continue 1252 } 1253 memUsage += int(sourceVP.GetBytes()) 1254 } 1255 } 1256 batch.RUnlock() 1257 } 1258 shard.Users.Done() 1259 } 1260 1261 return memUsage 1262 } 1263 1264 // FindDeviceForQuery calls device manager to find a device for the query 1265 func (qc *AQLQueryContext) FindDeviceForQuery(memStore memstore.MemStore, preferredDevice int, 1266 deviceManager *DeviceManager, timeout int) { 1267 memoryRequired := qc.calculateMemoryRequirement(memStore) 1268 if qc.Error != nil { 1269 return 1270 } 1271 1272 qc.OOPK.DeviceMemoryRequirement = memoryRequired 1273 1274 waitStart := utils.Now() 1275 device := deviceManager.FindDevice(qc.Query, memoryRequired, preferredDevice, timeout) 1276 if device == -1 { 1277 qc.Error = utils.StackError(nil, "Unable to find device to run this query") 1278 } 1279 qc.OOPK.DurationWaitedForDevice = utils.Now().Sub(waitStart) 1280 qc.Device = device 1281 } 1282 1283 func (qc *AQLQueryContext) runBatchExecutor(e BatchExecutor, isLastBatch bool) { 1284 start := utils.Now() 1285 e.preExec(isLastBatch, start) 1286 1287 e.filter() 1288 1289 e.join() 1290 1291 e.project() 1292 1293 e.reduce() 1294 1295 e.postExec(start) 1296 } 1297 1298 // copyHostToDevice copy vector party slice to device vector party slice 1299 func copyHostToDevice(vps memCom.HostVectorPartySlice, deviceVPSlice deviceVectorPartySlice, stream unsafe.Pointer, device int) (bytesCopied, numTransfers int) { 1300 if vps.ValueBytes > 0 { 1301 memutils.AsyncCopyHostToDevice( 1302 deviceVPSlice.values.getPointer(), vps.Values, vps.ValueBytes, 1303 stream, device) 1304 bytesCopied += vps.ValueBytes 1305 numTransfers++ 1306 } 1307 if vps.NullBytes > 0 { 1308 memutils.AsyncCopyHostToDevice( 1309 deviceVPSlice.nulls.getPointer(), vps.Nulls, vps.NullBytes, 1310 stream, device) 1311 bytesCopied += vps.NullBytes 1312 numTransfers++ 1313 } 1314 if vps.CountBytes > 0 { 1315 memutils.AsyncCopyHostToDevice( 1316 deviceVPSlice.counts.getPointer(), vps.Counts, vps.CountBytes, 1317 stream, device) 1318 bytesCopied += vps.CountBytes 1319 numTransfers++ 1320 } 1321 return 1322 } 1323 1324 func hostToDeviceColumn(hostColumn memCom.HostVectorPartySlice, device int) deviceVectorPartySlice { 1325 deviceColumn := deviceVectorPartySlice{ 1326 length: hostColumn.Length, 1327 valueType: hostColumn.ValueType, 1328 defaultValue: hostColumn.DefaultValue, 1329 valueStartIndex: hostColumn.ValueStartIndex, 1330 nullStartIndex: hostColumn.NullStartIndex, 1331 countStartIndex: hostColumn.CountStartIndex, 1332 } 1333 totalColumnBytes := hostColumn.ValueBytes + hostColumn.NullBytes + hostColumn.CountBytes 1334 1335 if totalColumnBytes > 0 { 1336 deviceColumn.basePtr = deviceAllocate(totalColumnBytes, device) 1337 if hostColumn.Counts != nil { 1338 deviceColumn.counts = deviceColumn.basePtr.offset(0) 1339 } 1340 1341 if hostColumn.Nulls != nil { 1342 deviceColumn.nulls = deviceColumn.basePtr.offset(hostColumn.CountBytes) 1343 } 1344 1345 deviceColumn.values = deviceColumn.basePtr.offset( 1346 hostColumn.NullBytes + hostColumn.CountBytes) 1347 } 1348 return deviceColumn 1349 } 1350 1351 // shouldSkipLiveBatch will determine whether we can skip processing a live batch by checking time filter and 1352 // eligible main table common filters. The batch must be non nil. 1353 func (qc *AQLQueryContext) shouldSkipLiveBatch(b *memstore.LiveBatch) bool { 1354 candidatesFilters := []expr.Expr{qc.OOPK.TimeFilters[0], qc.OOPK.TimeFilters[1]} 1355 candidatesFilters = append(candidatesFilters, qc.OOPK.MainTableCommonFilters...) 1356 for _, filter := range candidatesFilters { 1357 if shouldSkipLiveBatchWithFilter(b, filter) { 1358 return true 1359 } 1360 } 1361 return false 1362 } 1363 1364 // shouldSkipLiveBatchWithFilter will check max and min for the corresponding column against the filter express and 1365 // determines whether we should skip processing this live batch. 1366 // Following constraints apply: 1367 // 1. Filter must be on main table. 1368 // 2. Filter must be a binary expression. 1369 // 3. OPs must be one of (EQ, GTE,GE,LTE,LE). 1370 // 4. One side of the expr must be VarRef 1371 // 5. Another side of the xpr must be NumericalLiteral 1372 // 6. ColumnType must be UInt32 1373 func shouldSkipLiveBatchWithFilter(b *memstore.LiveBatch, filter expr.Expr) bool { 1374 if filter == nil { 1375 return false 1376 } 1377 1378 if binExpr, ok := filter.(*expr.BinaryExpr); ok { 1379 var columnExpr *expr.VarRef 1380 var numExpr *expr.NumberLiteral 1381 op := binExpr.Op 1382 switch op { 1383 case expr.GTE, expr.GT, expr.LT, expr.LTE, expr.EQ: 1384 default: 1385 return false 1386 } 1387 // First try lhs VarRef, rhs Num. 1388 lhsVarRef, lhsOK := binExpr.LHS.(*expr.VarRef) 1389 rhsNum, rhsOK := binExpr.RHS.(*expr.NumberLiteral) 1390 if lhsOK && rhsOK { 1391 columnExpr = lhsVarRef 1392 numExpr = rhsNum 1393 } else { 1394 // Then try rhs VarRef, lhs Num. 1395 lhsNum, lhsOK := binExpr.LHS.(*expr.NumberLiteral) 1396 rhsVarRef, rhsOK := binExpr.RHS.(*expr.VarRef) 1397 if lhsOK && rhsOK { 1398 // Swap column to the left and number to right. 1399 columnExpr = rhsVarRef 1400 numExpr = lhsNum 1401 // Invert the OP. 1402 switch op { 1403 case expr.GTE: 1404 op = expr.LTE 1405 case expr.GT: 1406 op = expr.LT 1407 case expr.LTE: 1408 op = expr.GTE 1409 case expr.LT: 1410 op = expr.GT 1411 } 1412 } 1413 } 1414 1415 if columnExpr != nil && numExpr != nil { 1416 // Time filters and main table filters are guaranteed to be on main table. 1417 vp := b.Columns[columnExpr.ColumnID] 1418 if vp == nil { 1419 return true 1420 } 1421 1422 if columnExpr.DataType != memCom.Uint32 { 1423 return false 1424 } 1425 1426 num := int64(numExpr.Int) 1427 minUint32, maxUint32 := vp.(memCom.LiveVectorParty).GetMinMaxValue() 1428 min, max := int64(minUint32), int64(maxUint32) 1429 switch op { 1430 case expr.GTE: 1431 return max < num 1432 case expr.GT: 1433 return max <= num 1434 case expr.LTE: 1435 return min > num 1436 case expr.LT: 1437 return min >= num 1438 case expr.EQ: 1439 return min > num || max < num 1440 } 1441 } 1442 } 1443 return false 1444 }