github.com/rohankumardubey/aresdb@v0.0.2-0.20190517170215-e54e3ca06b9c/query/aql_processor.go (about)

     1  //  Copyright (c) 2017-2018 Uber Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package query
    16  
    17  import (
    18  	"math"
    19  	"unsafe"
    20  
    21  	"encoding/binary"
    22  	"github.com/uber/aresdb/memstore"
    23  	memCom "github.com/uber/aresdb/memstore/common"
    24  	"github.com/uber/aresdb/memutils"
    25  	queryCom "github.com/uber/aresdb/query/common"
    26  	"github.com/uber/aresdb/query/expr"
    27  	"github.com/uber/aresdb/utils"
    28  	"time"
    29  )
    30  
    31  const (
    32  	hllQueryRequiredMemoryInMB = 10 * 1024
    33  )
    34  
    35  // batchTransferExecutor defines the type of the functor to transfer a live batch or a archive batch
    36  // from host memory to device memory. hostVPs will be the columns to be released after transfer. startRow
    37  // is used to slice the vector party.
    38  type batchTransferExecutor func(stream unsafe.Pointer) (deviceColumns []deviceVectorPartySlice,
    39  	hostVPs []memCom.VectorParty, firstColumn, startRow, totalBytes, numTransfers int)
    40  
    41  // customFilterExecutor is the functor to apply custom filters depends on the batch type. For archive batch,
    42  // the custom filter will be the time filter and will only be applied to first or last batch. For live batch,
    43  // the custom filters will be the cutoff time filter if cutoff is larger than 0, pre-filters and time filters.
    44  type customFilterExecutor func(stream unsafe.Pointer)
    45  
    46  // ProcessQuery processes the compiled query and executes it on GPU.
    47  func (qc *AQLQueryContext) ProcessQuery(memStore memstore.MemStore) {
    48  	defer func() {
    49  		if r := recover(); r != nil {
    50  			// find out exactly what the error was and set err
    51  			switch x := r.(type) {
    52  			case string:
    53  				qc.Error = utils.StackError(nil, x)
    54  			case error:
    55  				qc.Error = utils.StackError(x, "Panic happens when processing query")
    56  			default:
    57  				qc.Error = utils.StackError(nil, "Panic happens when processing query %v", x)
    58  			}
    59  			utils.GetLogger().Error("Releasing device memory after panic")
    60  			qc.Release()
    61  		}
    62  	}()
    63  
    64  	qc.cudaStreams[0] = memutils.CreateCudaStream(qc.Device)
    65  	qc.cudaStreams[1] = memutils.CreateCudaStream(qc.Device)
    66  	qc.OOPK.currentBatch.device = qc.Device
    67  	qc.OOPK.LiveBatchStats = oopkQueryStats{
    68  		Name2Stage: make(map[stageName]*oopkStageSummaryStats),
    69  	}
    70  	qc.OOPK.ArchiveBatchStats = oopkQueryStats{
    71  		Name2Stage: make(map[stageName]*oopkStageSummaryStats),
    72  	}
    73  
    74  	previousBatchExecutor := NewDummyBatchExecutor()
    75  
    76  	start := utils.Now()
    77  	for joinTableID, join := range qc.Query.Joins {
    78  		qc.prepareForeignTable(memStore, joinTableID, join)
    79  		if qc.Error != nil {
    80  			return
    81  		}
    82  	}
    83  	qc.reportTiming(qc.cudaStreams[0], &start, prepareForeignTableTiming)
    84  
    85  	qc.prepareTimezoneTable(memStore)
    86  	if qc.Error != nil {
    87  		return
    88  	}
    89  
    90  	// prepare geo intersection
    91  	if qc.OOPK.geoIntersection != nil {
    92  		shapeExists := qc.prepareForGeoIntersect(memStore)
    93  		if qc.Error != nil {
    94  			return
    95  		}
    96  		if !shapeExists {
    97  			// if no shape exist and geo check for point in shape
    98  			// no need to continue processing batch
    99  			if qc.OOPK.geoIntersection.inOrOut {
   100  				return
   101  			}
   102  			// if no shape exist and geo check for point not in shape
   103  			// no need to do geo intersection
   104  			qc.OOPK.geoIntersection = nil
   105  		}
   106  	}
   107  
   108  	for _, shardID := range qc.TableScanners[0].Shards {
   109  		previousBatchExecutor = qc.processShard(memStore, shardID, previousBatchExecutor)
   110  		if qc.Error != nil {
   111  			return
   112  		}
   113  		if qc.OOPK.done {
   114  			break
   115  		}
   116  	}
   117  
   118  	// query execution for last batch.
   119  	qc.runBatchExecutor(previousBatchExecutor, true)
   120  
   121  	// this code snippet does the followings:
   122  	// 1. write stats to log.
   123  	// 2. allocate host buffer for result and copy the result from device to host.
   124  	// 3. clean up device status buffers if no panic.
   125  	if qc.Debug {
   126  		qc.OOPK.LiveBatchStats.writeToLog()
   127  		qc.OOPK.ArchiveBatchStats.writeToLog()
   128  	}
   129  
   130  	start = utils.Now()
   131  	if qc.Error == nil {
   132  		// Copy the result to host memory.
   133  		qc.OOPK.ResultSize = qc.OOPK.currentBatch.resultSize
   134  		if qc.OOPK.IsHLL() {
   135  			qc.HLLQueryResult, qc.Error = qc.PostprocessAsHLLData()
   136  		} else {
   137  			// copy dimensions
   138  			qc.OOPK.dimensionVectorH = memutils.HostAlloc(qc.OOPK.ResultSize * qc.OOPK.DimRowBytes)
   139  			asyncCopyDimensionVector(qc.OOPK.dimensionVectorH, qc.OOPK.currentBatch.dimensionVectorD[0].getPointer(), qc.OOPK.ResultSize, 0,
   140  				qc.OOPK.NumDimsPerDimWidth, qc.OOPK.ResultSize, qc.OOPK.currentBatch.resultCapacity,
   141  				memutils.AsyncCopyDeviceToHost, qc.cudaStreams[0], qc.Device)
   142  			if !qc.isNonAggregationQuery {
   143  				// copy measures
   144  				qc.OOPK.measureVectorH = memutils.HostAlloc(qc.OOPK.ResultSize * qc.OOPK.MeasureBytes)
   145  				memutils.AsyncCopyDeviceToHost(
   146  					qc.OOPK.measureVectorH, qc.OOPK.currentBatch.measureVectorD[0].getPointer(),
   147  					qc.OOPK.ResultSize*qc.OOPK.MeasureBytes, qc.cudaStreams[0], qc.Device)
   148  			}
   149  			memutils.WaitForCudaStream(qc.cudaStreams[0], qc.Device)
   150  		}
   151  	}
   152  	qc.reportTiming(qc.cudaStreams[0], &start, resultTransferTiming)
   153  	qc.cleanUpDeviceStatus()
   154  	qc.reportTiming(nil, &start, finalCleanupTiming)
   155  }
   156  
   157  func (qc *AQLQueryContext) processShard(memStore memstore.MemStore, shardID int, previousBatchExecutor BatchExecutor) BatchExecutor {
   158  	var liveRecordsProcessed, archiveRecordsProcessed, liveBatchProcessed, archiveBatchProcessed, liveBytesTransferred, archiveBytesTransferred int
   159  	shard, err := memStore.GetTableShard(qc.Query.Table, shardID)
   160  	if err != nil {
   161  		qc.Error = utils.StackError(err, "failed to get shard %d for table %s",
   162  			shardID, qc.Query.Table)
   163  		return previousBatchExecutor
   164  	}
   165  	defer shard.Users.Done()
   166  
   167  	var archiveStore *memstore.ArchiveStoreVersion
   168  	var cutoff uint32
   169  	if shard.Schema.Schema.IsFactTable {
   170  		archiveStore = shard.ArchiveStore.GetCurrentVersion()
   171  		defer archiveStore.Users.Done()
   172  		cutoff = archiveStore.ArchivingCutoff
   173  	}
   174  
   175  	// Process live batches.
   176  	if qc.toTime == nil || cutoff < uint32(qc.toTime.Time.Unix()) {
   177  		batchIDs, numRecordsInLastBatch := shard.LiveStore.GetBatchIDs()
   178  		for i, batchID := range batchIDs {
   179  			if qc.OOPK.done {
   180  				break
   181  			}
   182  			batch := shard.LiveStore.GetBatchForRead(batchID)
   183  			if batch == nil {
   184  				continue
   185  			}
   186  
   187  			// For now, dimension table does not persist min and max therefore
   188  			// we can only skip live batch for fact table.
   189  			// TODO: Persist min/max/numTrues when snapshotting.
   190  			if shard.Schema.Schema.IsFactTable && qc.shouldSkipLiveBatch(batch) {
   191  				batch.RUnlock()
   192  				qc.OOPK.LiveBatchStats.NumBatchSkipped++
   193  				continue
   194  			}
   195  
   196  			liveBatchProcessed++
   197  			size := batch.Capacity
   198  			if i == len(batchIDs)-1 {
   199  				size = numRecordsInLastBatch
   200  			}
   201  			liveRecordsProcessed += size
   202  			previousBatchExecutor = qc.processBatch(&batch.Batch,
   203  				batchID,
   204  				size,
   205  				qc.transferLiveBatch(batch, size),
   206  				qc.liveBatchCustomFilterExecutor(cutoff), previousBatchExecutor, true)
   207  			qc.cudaStreams[0], qc.cudaStreams[1] = qc.cudaStreams[1], qc.cudaStreams[0]
   208  			liveBytesTransferred += qc.OOPK.currentBatch.stats.bytesTransferred
   209  		}
   210  	}
   211  
   212  	// Process archive batches.
   213  	if archiveStore != nil && (qc.fromTime == nil || cutoff > uint32(qc.fromTime.Time.Unix())) {
   214  		scanner := qc.TableScanners[0]
   215  		for batchID := scanner.ArchiveBatchIDStart; batchID < scanner.ArchiveBatchIDEnd; batchID++ {
   216  			if qc.OOPK.done {
   217  				break
   218  			}
   219  			archiveBatch := archiveStore.RequestBatch(int32(batchID))
   220  			if archiveBatch.Size == 0 {
   221  				qc.OOPK.ArchiveBatchStats.NumBatchSkipped++
   222  				continue
   223  			}
   224  			isFirstOrLast := batchID == scanner.ArchiveBatchIDStart || batchID == scanner.ArchiveBatchIDEnd-1
   225  			previousBatchExecutor = qc.processBatch(
   226  				&archiveBatch.Batch,
   227  				int32(batchID),
   228  				archiveBatch.Size,
   229  				qc.transferArchiveBatch(archiveBatch, isFirstOrLast),
   230  				qc.archiveBatchCustomFilterExecutor(isFirstOrLast),
   231  				previousBatchExecutor, false)
   232  			archiveRecordsProcessed += archiveBatch.Size
   233  			archiveBatchProcessed++
   234  			qc.cudaStreams[0], qc.cudaStreams[1] = qc.cudaStreams[1], qc.cudaStreams[0]
   235  			archiveBytesTransferred += qc.OOPK.currentBatch.stats.bytesTransferred
   236  		}
   237  	}
   238  	utils.GetReporter(qc.Query.Table, shardID).GetCounter(utils.QueryLiveRecordsProcessed).Inc(int64(liveRecordsProcessed))
   239  	utils.GetReporter(qc.Query.Table, shardID).GetCounter(utils.QueryArchiveRecordsProcessed).Inc(int64(archiveRecordsProcessed))
   240  	utils.GetReporter(qc.Query.Table, shardID).GetCounter(utils.QueryLiveBatchProcessed).Inc(int64(liveBatchProcessed))
   241  	utils.GetReporter(qc.Query.Table, shardID).GetCounter(utils.QueryArchiveBatchProcessed).Inc(int64(archiveBatchProcessed))
   242  	utils.GetReporter(qc.Query.Table, shardID).GetCounter(utils.QueryLiveBytesTransferred).Inc(int64(liveBytesTransferred))
   243  	utils.GetReporter(qc.Query.Table, shardID).GetCounter(utils.QueryArchiveBytesTransferred).Inc(int64(archiveBytesTransferred))
   244  
   245  	return previousBatchExecutor
   246  }
   247  
   248  // Release releases all device memory it allocated. It **should only called** when any errors happens while the query is
   249  // processed.
   250  func (qc *AQLQueryContext) Release() {
   251  	// release device memory for processing current batch.
   252  	qc.OOPK.currentBatch.cleanupBeforeAggregation()
   253  	qc.OOPK.currentBatch.swapResultBufferForNextBatch()
   254  	qc.cleanUpDeviceStatus()
   255  	qc.ReleaseHostResultsBuffers()
   256  }
   257  
   258  // CleanUpDevice cleans up the device status including
   259  //  1. clean up the device buffer for storing results.
   260  //  2. clean up the cuda streams
   261  func (qc *AQLQueryContext) cleanUpDeviceStatus() {
   262  	// clean up foreign table memory after query
   263  	for _, foreignTable := range qc.OOPK.foreignTables {
   264  		qc.cleanUpForeignTable(foreignTable)
   265  	}
   266  	qc.OOPK.foreignTables = nil
   267  
   268  	// release geo pointers
   269  	if qc.OOPK.geoIntersection != nil {
   270  		deviceFreeAndSetNil(&qc.OOPK.geoIntersection.shapeLatLongs)
   271  	}
   272  
   273  	// Destroy streams
   274  	memutils.DestroyCudaStream(qc.cudaStreams[0], qc.Device)
   275  	memutils.DestroyCudaStream(qc.cudaStreams[1], qc.Device)
   276  	qc.cudaStreams = [2]unsafe.Pointer{nil, nil}
   277  
   278  	// Clean up the device result buffers.
   279  	qc.OOPK.currentBatch.cleanupDeviceResultBuffers()
   280  
   281  	// Clean up timezone lookup buffer.
   282  	deviceFreeAndSetNil(&qc.OOPK.currentBatch.timezoneLookupD)
   283  }
   284  
   285  // clean up foreign table
   286  func (qc *AQLQueryContext) cleanUpForeignTable(table *foreignTable) {
   287  	if table != nil {
   288  		deviceFreeAndSetNil(&table.devicePrimaryKeyPtr)
   289  		for _, batch := range table.batches {
   290  			for _, column := range batch {
   291  				deviceFreeAndSetNil(&column.basePtr)
   292  			}
   293  		}
   294  		table.batches = nil
   295  	}
   296  }
   297  
   298  // getGeoShapeLatLongSlice format GeoShapeGo into slices of float32 for query purpose
   299  // Lats and Longs are stored in the format as [a1,a2,...an,a1,MaxFloat32,b1,bz,...bn]
   300  // refer to time_series_aggregate.h for GeoShape struct
   301  func getGeoShapeLatLongSlice(shapesLats, shapesLongs []float32, gs memCom.GeoShapeGo) ([]float32, []float32, int) {
   302  	numPoints := 0
   303  	for i, polygon := range gs.Polygons {
   304  		if len(polygon) > 0 && i > 0 {
   305  			// write place holder at start of polygon
   306  			shapesLats = append(shapesLats, math.MaxFloat32)
   307  			shapesLongs = append(shapesLongs, math.MaxFloat32)
   308  			// FLT_MAX as placeholder for each polygon
   309  			numPoints++
   310  		}
   311  		for _, point := range polygon {
   312  			shapesLats = append(shapesLats, point[0])
   313  			shapesLongs = append(shapesLongs, point[1])
   314  			numPoints++
   315  		}
   316  	}
   317  	return shapesLats, shapesLongs, numPoints
   318  }
   319  
   320  func (qc *AQLQueryContext) prepareForGeoIntersect(memStore memstore.MemStore) (shapeExists bool) {
   321  	tableScanner := qc.TableScanners[qc.OOPK.geoIntersection.shapeTableID]
   322  	shapeColumnID := qc.OOPK.geoIntersection.shapeColumnID
   323  	tableName := tableScanner.Schema.Schema.Name
   324  	// geo table is not sharded
   325  	shard, err := memStore.GetTableShard(tableName, 0)
   326  	if err != nil {
   327  		qc.Error = utils.StackError(err, "Failed to get shard for table %s, shard: %d", tableName, 0)
   328  		return
   329  	}
   330  	defer shard.Users.Done()
   331  
   332  	numPointsPerShape := make([]int32, 0, len(qc.OOPK.geoIntersection.shapeUUIDs))
   333  	qc.OOPK.geoIntersection.validShapeUUIDs = make([]string, 0, len(qc.OOPK.geoIntersection.shapeUUIDs))
   334  	var shapesLats, shapesLongs []float32
   335  	var numPoints, totalNumPoints int
   336  	for _, uuid := range qc.OOPK.geoIntersection.shapeUUIDs {
   337  		recordID, found := shard.LiveStore.LookupKey([]string{uuid})
   338  		if found {
   339  			batch := shard.LiveStore.GetBatchForRead(recordID.BatchID)
   340  			if batch != nil {
   341  				shapeValue := batch.GetDataValue(int(recordID.Index), shapeColumnID)
   342  				// compiler should have verified the geo column GeoShape type
   343  				shapesLats, shapesLongs, numPoints = getGeoShapeLatLongSlice(shapesLats, shapesLongs, *(shapeValue.GoVal.(*memCom.GeoShapeGo)))
   344  				if numPoints > 0 {
   345  					totalNumPoints += numPoints
   346  					numPointsPerShape = append(numPointsPerShape, int32(numPoints))
   347  					qc.OOPK.geoIntersection.validShapeUUIDs = append(qc.OOPK.geoIntersection.validShapeUUIDs, uuid)
   348  					shapeExists = true
   349  				}
   350  				batch.RUnlock()
   351  			}
   352  		}
   353  	}
   354  
   355  	if !shapeExists {
   356  		return
   357  	}
   358  
   359  	numValidShapes := len(numPointsPerShape)
   360  	shapeIndexs := make([]uint8, totalNumPoints)
   361  	pointIndex := 0
   362  	for shapeIndex, numPoints := range numPointsPerShape {
   363  		for i := 0; i < int(numPoints); i++ {
   364  			shapeIndexs[pointIndex] = uint8(shapeIndex)
   365  			pointIndex++
   366  		}
   367  	}
   368  
   369  	// allocate memory for lats, longs (float32) and numPoints (int32) device vectors
   370  	latsPtrD := deviceAllocate(totalNumPoints*4*2+totalNumPoints, qc.Device)
   371  	longsPtrD := latsPtrD.offset(totalNumPoints * 4)
   372  	shapeIndexsD := longsPtrD.offset(totalNumPoints * 4)
   373  
   374  	memutils.AsyncCopyHostToDevice(latsPtrD.getPointer(), unsafe.Pointer(&shapesLats[0]), totalNumPoints*4, qc.cudaStreams[0], qc.Device)
   375  	memutils.AsyncCopyHostToDevice(longsPtrD.getPointer(), unsafe.Pointer(&shapesLongs[0]), totalNumPoints*4, qc.cudaStreams[0], qc.Device)
   376  	memutils.AsyncCopyHostToDevice(shapeIndexsD.getPointer(), unsafe.Pointer(&shapeIndexs[0]), totalNumPoints, qc.cudaStreams[0], qc.Device)
   377  
   378  	qc.OOPK.geoIntersection.shapeLatLongs = latsPtrD
   379  	qc.OOPK.geoIntersection.numShapes = numValidShapes
   380  	qc.OOPK.geoIntersection.totalNumPoints = totalNumPoints
   381  	return
   382  }
   383  
   384  // prepare foreign table (allocate and transfer memory) before processing
   385  func (qc *AQLQueryContext) prepareForeignTable(memStore memstore.MemStore, joinTableID int, join Join) {
   386  	ft := qc.OOPK.foreignTables[joinTableID]
   387  	if ft == nil {
   388  		return
   389  	}
   390  
   391  	// join only support dimension table for now
   392  	// and dimension table is not shared
   393  	shard, err := memStore.GetTableShard(join.Table, 0)
   394  	if err != nil {
   395  		qc.Error = utils.StackError(err, "Failed to get shard for table %s, shard: %d", join.Table, 0)
   396  		return
   397  	}
   398  	defer shard.Users.Done()
   399  
   400  	// only need live store for dimension table
   401  	batchIDs, numRecordsInLastBatch := shard.LiveStore.GetBatchIDs()
   402  	ft.numRecordsInLastBatch = numRecordsInLastBatch
   403  	deviceBatches := make([][]deviceVectorPartySlice, len(batchIDs))
   404  
   405  	// transfer primary key
   406  	hostPrimaryKeyData := shard.LiveStore.PrimaryKey.LockForTransfer()
   407  	devicePrimaryKeyPtr := deviceAllocate(hostPrimaryKeyData.NumBytes, qc.Device)
   408  	memutils.AsyncCopyHostToDevice(devicePrimaryKeyPtr.getPointer(), hostPrimaryKeyData.Data, hostPrimaryKeyData.NumBytes, qc.cudaStreams[0], qc.Device)
   409  	memutils.WaitForCudaStream(qc.cudaStreams[0], qc.Device)
   410  	ft.hostPrimaryKeyData = hostPrimaryKeyData
   411  	ft.devicePrimaryKeyPtr = devicePrimaryKeyPtr
   412  	shard.LiveStore.PrimaryKey.UnlockAfterTransfer()
   413  
   414  	// allocate device memory
   415  	for i, batchID := range batchIDs {
   416  		batch := shard.LiveStore.GetBatchForRead(batchID)
   417  		if batch == nil {
   418  			continue
   419  		}
   420  		batchIndex := batchID - memstore.BaseBatchID
   421  		deviceBatches[batchIndex] = make([]deviceVectorPartySlice, len(qc.TableScanners[joinTableID+1].Columns))
   422  
   423  		size := batch.Capacity
   424  		if i == len(batchIDs)-1 {
   425  			size = numRecordsInLastBatch
   426  		}
   427  		for i, columnID := range qc.TableScanners[joinTableID+1].Columns {
   428  			usage := qc.TableScanners[joinTableID+1].ColumnUsages[columnID]
   429  			if usage&(columnUsedByAllBatches|columnUsedByLiveBatches) != 0 {
   430  				sourceVP := batch.Columns[columnID]
   431  				if sourceVP == nil {
   432  					continue
   433  				}
   434  
   435  				hostVPSlice := sourceVP.(memstore.TransferableVectorParty).GetHostVectorPartySlice(0, size)
   436  				deviceBatches[batchIndex][i] = hostToDeviceColumn(hostVPSlice, qc.Device)
   437  				copyHostToDevice(hostVPSlice, deviceBatches[batchIndex][i], qc.cudaStreams[0], qc.Device)
   438  			}
   439  		}
   440  		memutils.WaitForCudaStream(qc.cudaStreams[0], qc.Device)
   441  		batch.RUnlock()
   442  	}
   443  	ft.batches = deviceBatches
   444  }
   445  
   446  // prepareTimezoneTable
   447  func (qc *AQLQueryContext) prepareTimezoneTable(store memstore.MemStore) {
   448  	if qc.timezoneTable.tableColumn == "" {
   449  		return
   450  	}
   451  
   452  	// Timezone table
   453  	timezoneTableName := utils.GetConfig().Query.TimezoneTable.TableName
   454  	schema, err := store.GetSchema(timezoneTableName)
   455  	if err != nil {
   456  		qc.Error = err
   457  		return
   458  	}
   459  	if schema == nil {
   460  		qc.Error = utils.StackError(nil, "unknown timezone table %s", timezoneTableName)
   461  		return
   462  	}
   463  
   464  	timer := utils.GetRootReporter().GetTimer(utils.TimezoneLookupTableCreationTime)
   465  	start := utils.Now()
   466  	defer func() {
   467  		duration := utils.Now().Sub(start)
   468  		timer.Record(duration)
   469  	}()
   470  
   471  	schema.RLock()
   472  	defer schema.RUnlock()
   473  
   474  	if tzDict, found := schema.EnumDicts[qc.timezoneTable.tableColumn]; found {
   475  		lookUp := make([]int16, len(tzDict.ReverseDict))
   476  		for i := range lookUp {
   477  			if loc, err := time.LoadLocation(tzDict.ReverseDict[i]); err == nil {
   478  				_, offset := time.Now().In(loc).Zone()
   479  				lookUp[i] = int16(offset)
   480  			} else {
   481  				qc.Error = utils.StackError(err, "error parsing timezone")
   482  				return
   483  			}
   484  		}
   485  		sizeInBytes := binary.Size(lookUp)
   486  		lookupPtr := deviceAllocate(sizeInBytes, qc.Device)
   487  		memutils.AsyncCopyHostToDevice(lookupPtr.getPointer(), unsafe.Pointer(&lookUp[0]), sizeInBytes, qc.cudaStreams[0], qc.Device)
   488  		qc.OOPK.currentBatch.timezoneLookupD = lookupPtr
   489  		qc.OOPK.currentBatch.timezoneLookupDSize = len(lookUp)
   490  	} else {
   491  		qc.Error = utils.StackError(nil, "unknown timezone column %s", qc.timezoneTable.tableColumn)
   492  		return
   493  	}
   494  
   495  }
   496  
   497  // transferLiveBatch returns a functor to transfer a live batch to device memory. The size parameter will be either the
   498  // size of the batch or num records in last batch. hostColumns will always be empty since we should not release a vector
   499  // party of a live batch. Start row will always be zero as well.
   500  func (qc *AQLQueryContext) transferLiveBatch(batch *memstore.LiveBatch, size int) batchTransferExecutor {
   501  	return func(stream unsafe.Pointer) (deviceColumns []deviceVectorPartySlice, hostVPs []memCom.VectorParty,
   502  		firstColumn, startRow, totalBytes, numTransfers int) {
   503  		// Allocate column inputs.
   504  		firstColumn = -1
   505  		deviceColumns = make([]deviceVectorPartySlice, len(qc.TableScanners[0].Columns))
   506  		for i, columnID := range qc.TableScanners[0].Columns {
   507  			usage := qc.TableScanners[0].ColumnUsages[columnID]
   508  			if usage&(columnUsedByAllBatches|columnUsedByLiveBatches) != 0 {
   509  				if firstColumn < 0 {
   510  					firstColumn = i
   511  				}
   512  				sourceVP := batch.Columns[columnID]
   513  				if sourceVP == nil {
   514  					continue
   515  				}
   516  
   517  				hostColumn := sourceVP.(memstore.TransferableVectorParty).GetHostVectorPartySlice(0, size)
   518  				deviceColumns[i] = hostToDeviceColumn(hostColumn, qc.Device)
   519  				b, t := copyHostToDevice(hostColumn, deviceColumns[i], stream, qc.Device)
   520  				totalBytes += b
   521  				numTransfers += t
   522  			}
   523  		}
   524  		return
   525  	}
   526  }
   527  
   528  // liveBatchTimeFilterExecutor returns a functor to apply custom time filters to live batch.
   529  func (qc *AQLQueryContext) liveBatchCustomFilterExecutor(cutoff uint32) customFilterExecutor {
   530  	return func(stream unsafe.Pointer) {
   531  		// cutoff filter evaluation.
   532  		// only apply to fact table where cutoff > 0
   533  		if cutoff > 0 {
   534  			qc.OOPK.currentBatch.processExpression(
   535  				qc.createCutoffTimeFilter(cutoff), nil,
   536  				qc.TableScanners, qc.OOPK.foreignTables, stream, qc.Device, qc.OOPK.currentBatch.filterAction)
   537  		}
   538  
   539  		// time filter evaluation
   540  		for _, filter := range qc.OOPK.TimeFilters {
   541  			if filter != nil {
   542  				qc.OOPK.currentBatch.processExpression(filter, nil,
   543  					qc.TableScanners, qc.OOPK.foreignTables, stream, qc.Device, qc.OOPK.currentBatch.filterAction)
   544  			}
   545  		}
   546  
   547  		// prefilter evaluation
   548  		for _, filter := range qc.OOPK.Prefilters {
   549  			qc.OOPK.currentBatch.processExpression(filter, nil,
   550  				qc.TableScanners, qc.OOPK.foreignTables, stream, qc.Device, qc.OOPK.currentBatch.filterAction)
   551  		}
   552  	}
   553  }
   554  
   555  // transferArchiveBatch returns the functor to transfer an archive batch to device memory. We will need to release
   556  // hostColumns after transfer completes.
   557  func (qc *AQLQueryContext) transferArchiveBatch(batch *memstore.ArchiveBatch,
   558  	isFirstOrLast bool) batchTransferExecutor {
   559  	return func(stream unsafe.Pointer) (deviceSlices []deviceVectorPartySlice, hostVPs []memCom.VectorParty,
   560  		firstColumn, startRow, totalBytes, numTransfers int) {
   561  		matchedColumnUsages := columnUsedByAllBatches
   562  		if isFirstOrLast {
   563  			matchedColumnUsages |= columnUsedByFirstArchiveBatch | columnUsedByLastArchiveBatch
   564  		}
   565  
   566  		// Request columns, prefilter-slicing, allocate column inputs.
   567  		firstColumn = -1
   568  		hostVPs = make([]memCom.VectorParty, len(qc.TableScanners[0].Columns))
   569  		hostSlices := make([]memCom.HostVectorPartySlice, len(qc.TableScanners[0].Columns))
   570  		deviceSlices = make([]deviceVectorPartySlice, len(qc.TableScanners[0].Columns))
   571  		endRow := batch.Size
   572  		prefilterIndex := 0
   573  		// Must iterate in reverse order to apply prefilter slicing properly.
   574  		for i := len(qc.TableScanners[0].Columns) - 1; i >= 0; i-- {
   575  			columnID := qc.TableScanners[0].Columns[i]
   576  			usage := qc.TableScanners[0].ColumnUsages[columnID]
   577  
   578  			if usage&matchedColumnUsages != 0 || usage&columnUsedByPrefilter != 0 {
   579  				// Request/pin column from disk and wait.
   580  				vp := batch.RequestVectorParty(columnID)
   581  				vp.WaitForDiskLoad()
   582  
   583  				// prefilter slicing
   584  				startRow, endRow, hostSlices[i] = qc.prefilterSlice(vp, prefilterIndex, startRow, endRow)
   585  				prefilterIndex++
   586  
   587  				if usage&matchedColumnUsages != 0 {
   588  					hostVPs[i] = vp
   589  					firstColumn = i
   590  					deviceSlices[i] = hostToDeviceColumn(hostSlices[i], qc.Device)
   591  				} else {
   592  					vp.Release()
   593  				}
   594  			}
   595  		}
   596  
   597  		for i, dstVPSlice := range deviceSlices {
   598  			columnID := qc.TableScanners[0].Columns[i]
   599  			usage := qc.TableScanners[0].ColumnUsages[columnID]
   600  			if usage&matchedColumnUsages != 0 {
   601  				srcVPSlice := hostSlices[i]
   602  				b, t := copyHostToDevice(srcVPSlice, dstVPSlice, stream, qc.Device)
   603  				totalBytes += b
   604  				numTransfers += t
   605  			}
   606  		}
   607  		return
   608  	}
   609  }
   610  
   611  // archiveBatchCustomFilterExecutor returns a functor to apply custom filter to first or last archive batch.
   612  func (qc *AQLQueryContext) archiveBatchCustomFilterExecutor(isFirstOrLast bool) customFilterExecutor {
   613  	return func(stream unsafe.Pointer) {
   614  		if isFirstOrLast {
   615  			for _, filter := range qc.OOPK.TimeFilters {
   616  				if filter != nil {
   617  					qc.OOPK.currentBatch.processExpression(filter, nil,
   618  						qc.TableScanners, qc.OOPK.foreignTables, stream, qc.Device, qc.OOPK.currentBatch.filterAction)
   619  				}
   620  			}
   621  		}
   622  	}
   623  }
   624  
   625  // helper function for copy dimension vector. Returns the total size of dimension vector.
   626  func asyncCopyDimensionVector(toDimVector, fromDimVector unsafe.Pointer, length, offset int, numDimsPerDimWidth queryCom.DimCountsPerDimWidth,
   627  	toVectorCapacity, fromVectorCapacity int, copyFunc memutils.AsyncMemCopyFunc,
   628  	stream unsafe.Pointer, device int) {
   629  
   630  	ptrFrom, ptrTo := fromDimVector, toDimVector
   631  	numNullVectors := 0
   632  	for _, numDims := range numDimsPerDimWidth {
   633  		numNullVectors += int(numDims)
   634  	}
   635  
   636  	dimBytes := 1 << uint(len(numDimsPerDimWidth)-1)
   637  	bytesToCopy := length * dimBytes
   638  	for _, numDim := range numDimsPerDimWidth {
   639  		for i := 0; i < int(numDim); i++ {
   640  			ptrTemp := utils.MemAccess(ptrTo, dimBytes*offset)
   641  			copyFunc(ptrTemp, ptrFrom, bytesToCopy, stream, device)
   642  			ptrTo = utils.MemAccess(ptrTo, dimBytes*toVectorCapacity)
   643  			ptrFrom = utils.MemAccess(ptrFrom, dimBytes*fromVectorCapacity)
   644  		}
   645  		dimBytes >>= 1
   646  		bytesToCopy = length * dimBytes
   647  	}
   648  
   649  	// copy null bytes
   650  	for i := 0; i < numNullVectors; i++ {
   651  		ptrTemp := utils.MemAccess(ptrTo, offset)
   652  		copyFunc(ptrTemp, ptrFrom, length, stream, device)
   653  		ptrTo = utils.MemAccess(ptrTo, toVectorCapacity)
   654  		ptrFrom = utils.MemAccess(ptrFrom, fromVectorCapacity)
   655  	}
   656  }
   657  
   658  // dimValueVectorSize returns the size of final dim value vector on host side.
   659  func dimValResVectorSize(resultSize int, numDimsPerDimWidth queryCom.DimCountsPerDimWidth) int {
   660  	totalDims := 0
   661  	for _, numDims := range numDimsPerDimWidth {
   662  		totalDims += int(numDims)
   663  	}
   664  
   665  	dimBytes := 1 << uint(len(numDimsPerDimWidth)-1)
   666  	var totalBytes int
   667  	for _, numDims := range numDimsPerDimWidth {
   668  		totalBytes += dimBytes * resultSize * int(numDims)
   669  		dimBytes >>= 1
   670  	}
   671  
   672  	totalBytes += totalDims * resultSize
   673  	return totalBytes
   674  }
   675  
   676  // cleanupDeviceResultBuffers cleans up result buffers and resets result fields.
   677  func (bc *oopkBatchContext) cleanupDeviceResultBuffers() {
   678  	deviceFreeAndSetNil(&bc.dimensionVectorD[0])
   679  	deviceFreeAndSetNil(&bc.dimensionVectorD[1])
   680  
   681  	deviceFreeAndSetNil(&bc.dimIndexVectorD[0])
   682  	deviceFreeAndSetNil(&bc.dimIndexVectorD[1])
   683  
   684  	deviceFreeAndSetNil(&bc.hashVectorD[0])
   685  	deviceFreeAndSetNil(&bc.hashVectorD[1])
   686  
   687  	deviceFreeAndSetNil(&bc.measureVectorD[0])
   688  	deviceFreeAndSetNil(&bc.measureVectorD[1])
   689  
   690  	bc.resultSize = 0
   691  	bc.resultCapacity = 0
   692  }
   693  
   694  // clean up memory not used in final aggregation (sort, reduce, hll)
   695  // before aggregation happen
   696  func (bc *oopkBatchContext) cleanupBeforeAggregation() {
   697  	for _, column := range bc.columns {
   698  		deviceFreeAndSetNil(&column.basePtr)
   699  	}
   700  	bc.columns = nil
   701  
   702  	deviceFreeAndSetNil(&bc.indexVectorD)
   703  	deviceFreeAndSetNil(&bc.predicateVectorD)
   704  	deviceFreeAndSetNil(&bc.geoPredicateVectorD)
   705  
   706  	for _, recordIDsVector := range bc.foreignTableRecordIDsD {
   707  		deviceFreeAndSetNil(&recordIDsVector)
   708  	}
   709  	bc.foreignTableRecordIDsD = nil
   710  
   711  	for _, stackFrame := range bc.exprStackD {
   712  		deviceFreeAndSetNil(&stackFrame[0])
   713  	}
   714  	bc.exprStackD = nil
   715  }
   716  
   717  // swapResultBufferForNextBatch swaps the two
   718  // sets of dim/measure/hash vectors to get ready for the next batch.
   719  func (bc *oopkBatchContext) swapResultBufferForNextBatch() {
   720  	bc.size = 0
   721  	bc.dimensionVectorD[0], bc.dimensionVectorD[1] = bc.dimensionVectorD[1], bc.dimensionVectorD[0]
   722  	bc.measureVectorD[0], bc.measureVectorD[1] = bc.measureVectorD[1], bc.measureVectorD[0]
   723  	bc.hashVectorD[0], bc.hashVectorD[1] = bc.hashVectorD[1], bc.hashVectorD[0]
   724  }
   725  
   726  // prepareForFiltering prepares the input and the index vectors for filtering.
   727  func (bc *oopkBatchContext) prepareForFiltering(
   728  	columns []deviceVectorPartySlice, firstColumn int, startRow int, stream unsafe.Pointer) {
   729  	bc.columns = columns
   730  	bc.startRow = startRow
   731  
   732  	if firstColumn >= 0 {
   733  		bc.size = columns[firstColumn].length
   734  		// Allocate twice of the size to save number of allocations of temporary index vector.
   735  		bc.indexVectorD = deviceAllocate(bc.size*4, bc.device)
   736  		bc.predicateVectorD = deviceAllocate(bc.size, bc.device)
   737  		bc.baseCountD = columns[firstColumn].counts.offset(columns[firstColumn].countStartIndex * 4)
   738  	}
   739  	bc.stats.batchSize = bc.size
   740  }
   741  
   742  // prepareForDimAndMeasureEval ensures that dim/measure vectors have enough
   743  // capacity for bc.resultSize+bc.size.
   744  func (bc *oopkBatchContext) prepareForDimAndMeasureEval(
   745  	dimRowBytes int, measureBytes int, numDimsPerDimWidth queryCom.DimCountsPerDimWidth, isHLL bool, stream unsafe.Pointer) {
   746  	if bc.resultSize+bc.size > bc.resultCapacity {
   747  		oldCapacity := bc.resultCapacity
   748  
   749  		bc.resultCapacity = bc.resultSize + bc.size
   750  		// Extra budget for future proofing.
   751  		bc.resultCapacity += bc.resultCapacity / 8
   752  
   753  		bc.dimensionVectorD = bc.reallocateResultBuffers(bc.dimensionVectorD, dimRowBytes, stream, func(to, from unsafe.Pointer) {
   754  			asyncCopyDimensionVector(to, from, bc.resultSize, 0,
   755  				numDimsPerDimWidth, bc.resultCapacity, oldCapacity,
   756  				memutils.AsyncCopyDeviceToDevice, stream, bc.device)
   757  		})
   758  
   759  		// uint32_t for index value
   760  		bc.dimIndexVectorD = bc.reallocateResultBuffers(bc.dimIndexVectorD, 4, stream, nil)
   761  		// uint64_t for hash value
   762  		// Note: only when aggregate function is hll, we need to reuse vector[0]
   763  		if isHLL {
   764  			bc.hashVectorD = bc.reallocateResultBuffers(bc.hashVectorD, 8, stream, func(to, from unsafe.Pointer) {
   765  				memutils.AsyncCopyDeviceToDevice(to, from, bc.resultSize*8, stream, bc.device)
   766  			})
   767  		} else {
   768  			bc.hashVectorD = bc.reallocateResultBuffers(bc.hashVectorD, 8, stream, nil)
   769  		}
   770  
   771  		bc.measureVectorD = bc.reallocateResultBuffers(bc.measureVectorD, measureBytes, stream, func(to, from unsafe.Pointer) {
   772  			memutils.AsyncCopyDeviceToDevice(to, from, bc.resultSize*measureBytes, stream, bc.device)
   773  		})
   774  	}
   775  }
   776  
   777  // reallocateResultBuffers reallocates the result buffer pair to size
   778  // resultCapacity*unitBytes and copies resultSize*unitBytes from input[0] to output[0].
   779  func (bc *oopkBatchContext) reallocateResultBuffers(
   780  	input [2]devicePointer, unitBytes int, stream unsafe.Pointer, copyFunc func(to, from unsafe.Pointer)) (output [2]devicePointer) {
   781  
   782  	output = [2]devicePointer{
   783  		deviceAllocate(bc.resultCapacity*unitBytes, bc.device),
   784  		deviceAllocate(bc.resultCapacity*unitBytes, bc.device),
   785  	}
   786  
   787  	if copyFunc != nil {
   788  		copyFunc(output[0].getPointer(), input[0].getPointer())
   789  	}
   790  
   791  	deviceFreeAndSetNil(&input[0])
   792  	deviceFreeAndSetNil(&input[1])
   793  	return
   794  }
   795  
   796  // doProfile checks the corresponding profileName against query parameter
   797  // and do cuda profiling for this action if name matches.
   798  func (qc *AQLQueryContext) doProfile(action func(), profileName string, stream unsafe.Pointer) {
   799  	if qc.Profiling == profileName {
   800  		// explicit waiting for cuda stream to avoid profiling previous actions.
   801  		memutils.WaitForCudaStream(stream, qc.Device)
   802  		utils.GetQueryLogger().Infof("Starting cuda profiler for %s", profileName)
   803  		memutils.CudaProfilerStart()
   804  		defer func() {
   805  			// explicit waiting for cuda stream to wait for completion of current action.
   806  			memutils.WaitForCudaStream(stream, qc.Device)
   807  			utils.GetQueryLogger().Infof("Stopping cuda profiler for %s", profileName)
   808  			memutils.CudaProfilerStop()
   809  		}()
   810  	}
   811  	action()
   812  }
   813  
   814  // processBatch allocates device memory and starts async input data
   815  // transferring to device memory. It then invokes previousBatchExecutor
   816  // asynchronously to process the previous batch. When both async operations
   817  // finish, it prepares for the current batch execution and returns it as
   818  // a function closure to be invoked later. customFilterExecutor is the executor
   819  // to apply custom filters for live batch and archive batch.
   820  func (qc *AQLQueryContext) processBatch(
   821  	batch *memstore.Batch, batchID int32, batchSize int, transferFunc batchTransferExecutor,
   822  	customFilterFunc customFilterExecutor, previousBatchExecutor BatchExecutor, needToUnlockBatch bool) BatchExecutor {
   823  	defer func() {
   824  		if needToUnlockBatch {
   825  			batch.RUnlock()
   826  		}
   827  	}()
   828  
   829  	if qc.Debug {
   830  		// Finish executing previous batch first to avoid timeline overlapping
   831  		qc.runBatchExecutor(previousBatchExecutor, false)
   832  		previousBatchExecutor = NewDummyBatchExecutor()
   833  	}
   834  
   835  	// reset stats.
   836  	qc.OOPK.currentBatch.stats = oopkBatchStats{
   837  		batchID: batchID,
   838  		timings: make(map[stageName]float64),
   839  	}
   840  	start := utils.Now()
   841  
   842  	// Async transfer.
   843  	stream := qc.cudaStreams[0]
   844  	deviceSlices, hostVPs, firstColumn, startRow, totalBytes, numTransfers := transferFunc(stream)
   845  	qc.OOPK.currentBatch.stats.bytesTransferred += totalBytes
   846  	qc.OOPK.currentBatch.stats.numTransferCalls += numTransfers
   847  
   848  	qc.reportTimingForCurrentBatch(stream, &start, transferTiming)
   849  
   850  	// Async execute the previous batch.
   851  	executionDone := make(chan struct{ error }, 1)
   852  	go func() {
   853  		defer func() {
   854  			if r := recover(); r != nil {
   855  				var err error
   856  				// find out exactly what the error was and set err
   857  				switch x := r.(type) {
   858  				case string:
   859  					err = utils.StackError(nil, x)
   860  				case error:
   861  					err = utils.StackError(x, "Panic happens when executing query")
   862  				default:
   863  					err = utils.StackError(nil, "Panic happens when executing query %v", x)
   864  				}
   865  				executionDone <- struct{ error }{err}
   866  			}
   867  		}()
   868  		qc.runBatchExecutor(previousBatchExecutor, false)
   869  		executionDone <- struct{ error }{}
   870  	}()
   871  
   872  	// Wait for data transfer of the current batch.
   873  	memutils.WaitForCudaStream(stream, qc.Device)
   874  
   875  	for _, vp := range hostVPs {
   876  		if vp != nil {
   877  			// only archive vector party will be returned after transfer function
   878  			vp.(memCom.ArchiveVectorParty).Release()
   879  		}
   880  	}
   881  
   882  	if needToUnlockBatch {
   883  		batch.RUnlock()
   884  		needToUnlockBatch = false
   885  	}
   886  
   887  	// Wait for execution of the previous batch.
   888  	res := <-executionDone
   889  	if res.error != nil {
   890  		// column data transfer for current batch is done
   891  		// need release current batch's column data before panic
   892  		for _, column := range deviceSlices {
   893  			deviceFreeAndSetNil(&column.basePtr)
   894  		}
   895  		panic(res.error)
   896  	}
   897  
   898  	if qc.OOPK.done {
   899  		// if the query is already satisfied in the middle, we can skip next batch and return
   900  		for _, column := range deviceSlices {
   901  			deviceFreeAndSetNil(&column.basePtr)
   902  		}
   903  		return NewDummyBatchExecutor()
   904  	}
   905  
   906  	// no prefilter slicing in livebatch, startRow is always 0
   907  	qc.OOPK.currentBatch.size = batchSize
   908  	qc.OOPK.currentBatch.prepareForFiltering(deviceSlices, firstColumn, startRow, stream)
   909  
   910  	qc.reportTimingForCurrentBatch(stream, &start, prepareForFilteringTiming)
   911  
   912  	return NewBatchExecutor(qc, batchID, customFilterFunc, stream)
   913  }
   914  
// prefilterSlice does the following:
// 1. binary search for prefilter values following the matched sort column order
// 2. record matched index range on these matched sort columns
// 3. binary search on unmatched compressed columns for the row number range
// 4. index slice on uncompressed columns for the row number range
// 5. align/pad all slices to be pushed
//
// Each call handles one archive vector party vp. prefilterIndex selects how vp is
// sliced against the prefilters of the first table scanner:
//   - prefilterIndex < len(EqualityPrefilterValues): vp is sliced by the equality
//     prefilter value at that index.
//   - prefilterIndex == len(EqualityPrefilterValues): vp is sliced by the range
//     prefilter. The lower bound is RangePrefilterValues[0] and the upper bound is
//     RangePrefilterValues[1], each with its own boundary type.
//   - otherwise: vp is treated as an unmatched column.
//
// Unmatched columns, and range-filtered columns missing either bound, end up
// sliced by the row range alone (vp.SliceIndex).
//
// startRow/endRow is the row range carried over from previously sliced columns.
// The returned row range may be narrower. The returned host slice covers
// [startIndex, endIndex) of vp.
// NOTE(review): assumes vp also implements memstore.TransferableVectorParty; the
// type assertion in the return statement panics otherwise.
func (qc *AQLQueryContext) prefilterSlice(vp memCom.ArchiveVectorParty, prefilterIndex, startRow, endRow int) (int, int, memCom.HostVectorPartySlice) {
	startIndex, endIndex := 0, vp.GetLength()

	unmatchedColumn := false
	scanner := qc.TableScanners[0]
	if prefilterIndex < len(scanner.EqualityPrefilterValues) {
		// matched equality filter
		// Presumably SliceByValue returns the row range and index range of the
		// entries equal to the value within the given row range. Verify in memstore.
		filterValue := scanner.EqualityPrefilterValues[prefilterIndex]
		startRow, endRow, startIndex, endIndex = vp.SliceByValue(startRow, endRow, unsafe.Pointer(&filterValue))
	} else if prefilterIndex == len(scanner.EqualityPrefilterValues) {
		// matched range filter
		// lower bound
		filterValue := scanner.RangePrefilterValues[0]
		boundaryType := scanner.RangePrefilterBoundaries[0]

		if boundaryType != noBoundary {
			lowerStartRow, lowerEndRow, lowerStartIndex, lowerEndIndex := vp.SliceByValue(startRow, endRow, unsafe.Pointer(&filterValue))
			// Inclusive bound: begin at the entries equal to the bound.
			// Exclusive bound: begin right after them.
			if boundaryType == inclusiveBoundary {
				startRow, startIndex = lowerStartRow, lowerStartIndex
			} else {
				startRow, startIndex = lowerEndRow, lowerEndIndex
			}
		} else {
			// treat as unmatchedColumn when there is one range filter missing
			unmatchedColumn = true
		}

		// SliceByValue of upperBound
		// The search runs within the row range already narrowed by the lower bound.
		filterValue = scanner.RangePrefilterValues[1]
		boundaryType = scanner.RangePrefilterBoundaries[1]
		if boundaryType != noBoundary {
			upperStartRow, upperEndRow, upperStartIndex, upperEndIndex := vp.SliceByValue(startRow, endRow, unsafe.Pointer(&filterValue))
			// Inclusive bound: end after the entries equal to the bound.
			// Exclusive bound: end before them.
			if boundaryType == inclusiveBoundary {
				endRow, endIndex = upperEndRow, upperEndIndex
			} else {
				endRow, endIndex = upperStartRow, upperStartIndex
			}
		} else {
			// treat as unmatchedColumn when there is one range filter missing
			unmatchedColumn = true
		}
	} else {
		unmatchedColumn = true
	}

	if unmatchedColumn {
		// unmatched columns, simply slice based on row number range
		// This overrides any index range computed above, but keeps the narrowed rows.
		startIndex, endIndex = vp.SliceIndex(startRow, endRow)
	}

	return startRow, endRow, vp.(memstore.TransferableVectorParty).GetHostVectorPartySlice(startIndex, endIndex-startIndex)
}
   973  
// calculateMemoryRequirement estimates the device memory required to process the
// query's batch data.
//
// The result is the largest per-batch estimate over all shards of the main table,
// plus the memory needed for the joined foreign tables. Per shard, it considers:
//   - the first non-nil live batch, when the query's time range can reach past
//     the archiving cutoff;
//   - every non-empty archive batch in the scanner's archive batch ID range, when
//     the query's time range can start before the cutoff.
//
// HLL queries short-circuit to a fixed estimate. If a shard cannot be obtained,
// qc.Error is set and -1 is returned. calculateForeignTableMemUsage may also set
// qc.Error.
func (qc *AQLQueryContext) calculateMemoryRequirement(memStore memstore.MemStore) int {
	// keep track of max requirement for batch
	maxBytesRequired := 0

	//TODO(jians): hard code hll query memory requirement here for now,
	//we can track memory usage
	//based on table, dimensions, duration to do estimation
	// NOTE(review): the constant is named "...InMB", but it is returned unconverted
	// while every other path returns a byte count. Confirm which unit
	// DeviceManager.FindDevice expects.
	if qc.OOPK.IsHLL() {
		return hllQueryRequiredMemoryInMB
	}

	for _, shardID := range qc.TableScanners[0].Shards {
		shard, err := memStore.GetTableShard(qc.Query.Table, shardID)
		if err != nil {
			qc.Error = utils.StackError(err, "failed to get shard %d for table %s",
				shardID, qc.Query.Table)
			return -1
		}

		// Archive version and cutoff exist only for fact tables; cutoff stays 0 otherwise.
		var archiveStore *memstore.ArchiveStoreVersion
		var cutoff uint32
		if shard.Schema.Schema.IsFactTable {
			archiveStore = shard.ArchiveStore.GetCurrentVersion()
			cutoff = archiveStore.ArchivingCutoff
		}

		// estimate live batch memory usage
		// Live data matters only when the query may end after the archiving cutoff.
		if qc.toTime == nil || cutoff < uint32(qc.toTime.Time.Unix()) {
			batchIDs, _ := shard.LiveStore.GetBatchIDs()

			// find first non null batch and estimate.
			// Only one live batch is sampled per shard; presumably live batches
			// share the same capacity. Confirm.
			for _, batchID := range batchIDs {
				liveBatch := shard.LiveStore.GetBatchForRead(batchID)
				if liveBatch != nil {
					batchBytes := qc.estimateLiveBatchMemoryUsage(liveBatch)
					liveBatch.RUnlock()

					if batchBytes > maxBytesRequired {
						maxBytesRequired = batchBytes
					}
					break
				}
			}
		}

		// estimate archive batch memory usage
		if archiveStore != nil {
			// Archived data matters only when the query may start before the cutoff.
			if qc.fromTime == nil || cutoff > uint32(qc.fromTime.Time.Unix()) {
				scanner := qc.TableScanners[0]
				for batchID := scanner.ArchiveBatchIDStart; batchID < scanner.ArchiveBatchIDEnd; batchID++ {
					archiveBatch := archiveStore.RequestBatch(int32(batchID))
					if archiveBatch == nil || archiveBatch.Size == 0 {
						continue
					}
					// The first and last batches of the range may need extra columns
					// (the first/last archive batch column usages).
					isFirstOrLast := batchID == scanner.ArchiveBatchIDStart || batchID == scanner.ArchiveBatchIDEnd-1
					batchBytes := qc.estimateArchiveBatchMemoryUsage(archiveBatch, isFirstOrLast)
					if batchBytes > maxBytesRequired {
						maxBytesRequired = batchBytes
					}
				}
			}
			// Presumably balances the reference taken by GetCurrentVersion.
			archiveStore.Users.Done()
		}
		// Presumably balances the reference taken by GetTableShard.
		shard.Users.Done()
	}

	maxBytesRequired += qc.calculateForeignTableMemUsage(memStore)
	return maxBytesRequired
}
  1044  
  1045  // estimateLiveBatchMemoryUsage estimate the GPU memory usage for live batches
  1046  func (qc *AQLQueryContext) estimateLiveBatchMemoryUsage(batch *memstore.LiveBatch) int {
  1047  	columnMemUsage := 0
  1048  	for _, columnID := range qc.TableScanners[0].Columns {
  1049  		sourceVP := batch.Columns[columnID]
  1050  		if sourceVP == nil {
  1051  			continue
  1052  		}
  1053  		columnMemUsage += int(sourceVP.GetBytes())
  1054  	}
  1055  
  1056  	totalBytes := qc.estimateMemUsageForBatch(batch.Capacity, columnMemUsage)
  1057  	utils.GetQueryLogger().Debugf("Live batch %+v needs memory: %d", batch, totalBytes)
  1058  	return totalBytes
  1059  }
  1060  
// estimateArchiveBatchMemoryUsage estimates the GPU memory usage for an archive
// batch. It returns 0 for a nil batch.
//
// Columns of the first table scanner are visited in reverse order. A column is
// sliced with prefilterSlice when either:
//   - its usage matches the batch (used by all batches, plus first/last archive
//     batch usages when isFirstOrLast is set), or
//   - it is used by a prefilter.
//
// Slicing may narrow the row range for the columns visited after it. Only the
// matched columns add their value/null/count bytes. The row count passed to
// estimateMemUsageForBatch is the slice length of the last matched column visited.
func (qc *AQLQueryContext) estimateArchiveBatchMemoryUsage(batch *memstore.ArchiveBatch, isFirstOrLast bool) int {
	if batch == nil {
		return 0
	}

	columnMemUsage := 0
	var firstColumnSize int
	startRow, endRow := 0, batch.Size
	var hostSlice memCom.HostVectorPartySlice

	matchedColumnUsages := columnUsedByAllBatches
	if isFirstOrLast {
		matchedColumnUsages |= columnUsedByFirstArchiveBatch | columnUsedByLastArchiveBatch
	}

	// prefilterIndex advances once per sliced column, in the same reverse column order.
	prefilterIndex := 0
	for i := len(qc.TableScanners[0].Columns) - 1; i >= 0; i-- {
		columnID := qc.TableScanners[0].Columns[i]
		usage := qc.TableScanners[0].ColumnUsages[columnID]
		// TODO(cdavid): only read metadata when estimate query memory requirement.
		sourceVP := batch.RequestVectorParty(columnID)
		sourceVP.WaitForDiskLoad()

		if usage&matchedColumnUsages != 0 || usage&columnUsedByPrefilter != 0 {
			// startRow/endRow carry the (possibly narrowed) row range to the next column.
			startRow, endRow, hostSlice = qc.prefilterSlice(sourceVP, prefilterIndex, startRow, endRow)
			prefilterIndex++
			if usage&matchedColumnUsages != 0 {
				columnMemUsage += hostSlice.ValueBytes + hostSlice.NullBytes + hostSlice.CountBytes
				firstColumnSize = hostSlice.Length
			}
		}
		sourceVP.Release()
	}

	totalBytes := qc.estimateMemUsageForBatch(firstColumnSize, columnMemUsage)
	utils.GetQueryLogger().Debugf("Archive batch %d needs memory: %d", batch.BatchID, totalBytes)
	return totalBytes
}
  1100  
  1101  // estimateMemUsageForBatch calculates memory usage including:
  1102  // * Index vector
  1103  // * Predicate vector
  1104  // * Dimension
  1105  // * Measurement
  1106  // * Sort (hash/index)
  1107  // * Reduce
  1108  func (qc *AQLQueryContext) estimateMemUsageForBatch(firstColumnSize int, columnMemUsage int) (memUsage int) {
  1109  	// 1. columnMemUsage
  1110  	memUsageBeforeAgg := columnMemUsage
  1111  
  1112  	// 2. index vector memory usage (4 bytes each)
  1113  	memUsageBeforeAgg += firstColumnSize * 4
  1114  
  1115  	// 3. predicate memory usage (1 byte each)
  1116  	memUsageBeforeAgg += firstColumnSize
  1117  
  1118  	// 4. record id vector for foreign table (8 bytes each recordID)
  1119  	memUsageBeforeAgg += firstColumnSize * 8 * len(qc.OOPK.foreignTables)
  1120  
  1121  	// 5. expression eval memory (max scratch space)
  1122  	memUsageBeforeAgg += qc.estimateExpressionEvaluationMemUsage(firstColumnSize)
  1123  
  1124  	// 6. geoPredicateVector
  1125  	if qc.OOPK.geoIntersection != nil {
  1126  		memUsageBeforeAgg += firstColumnSize * 4 * 2
  1127  	}
  1128  
  1129  	// 7. max(memUsageBeforeAgg, sortReduceMemoryUsage)
  1130  	memUsage = int(math.Max(float64(memUsageBeforeAgg), float64(estimateSortReduceMemUsage(firstColumnSize))))
  1131  
  1132  	// 8. Dimension vector memory usage (input + output)
  1133  	memUsage += firstColumnSize * qc.OOPK.DimRowBytes * 2
  1134  
  1135  	// 9. Measure vector memory usage (input + output)
  1136  	memUsage += firstColumnSize * qc.OOPK.MeasureBytes * 2
  1137  
  1138  	return
  1139  }
  1140  
  1141  // memory usage duration expression (filter, dimension, measure) evaluation
  1142  func (qc *AQLQueryContext) estimateExpressionEvaluationMemUsage(inputSize int) (memUsage int) {
  1143  	// filter expression evaluation
  1144  	for _, filter := range qc.OOPK.MainTableCommonFilters {
  1145  		_, maxExpMemUsage := estimateScratchSpaceMemUsage(filter, inputSize, true)
  1146  		utils.GetQueryLogger().Debugf("Filter %+v: maxExpMemUsage=%d", filter, maxExpMemUsage)
  1147  		memUsage = int(math.Max(float64(memUsage), float64(maxExpMemUsage)))
  1148  	}
  1149  
  1150  	for _, filter := range qc.OOPK.ForeignTableCommonFilters {
  1151  		_, maxExpMemUsage := estimateScratchSpaceMemUsage(filter, inputSize, true)
  1152  		utils.GetQueryLogger().Debugf("Filter %+v: maxExpMemUsage=%d", filter, maxExpMemUsage)
  1153  		memUsage = int(math.Max(float64(memUsage), float64(maxExpMemUsage)))
  1154  	}
  1155  
  1156  	// dimension expression evaluation
  1157  	for _, dimension := range qc.OOPK.Dimensions {
  1158  		_, maxExpMemUsage := estimateScratchSpaceMemUsage(dimension, inputSize, true)
  1159  		utils.GetQueryLogger().Debugf("Dimension %+v: maxExpMemUsage=%d", dimension, maxExpMemUsage)
  1160  		memUsage = int(math.Max(float64(memUsage), float64(maxExpMemUsage)))
  1161  	}
  1162  
  1163  	// measure expression evaluation
  1164  	_, maxExpMemUsage := estimateScratchSpaceMemUsage(qc.OOPK.Measure, inputSize, true)
  1165  	utils.GetQueryLogger().Debugf("Measure %+v: maxExpMemUsage=%d", qc.OOPK.Measure, maxExpMemUsage)
  1166  	memUsage = int(math.Max(float64(memUsage), float64(maxExpMemUsage)))
  1167  
  1168  	return memUsage
  1169  }
  1170  
  1171  // Note: we only calculate Sort memory usage
  1172  // since sort memory usage is larger than reduce
  1173  // and we only care about the maximum
  1174  func estimateSortReduceMemUsage(inputSize int) (memUsage int) {
  1175  	// dimension index vector
  1176  	// 4 byte for uint32
  1177  	// 2 vectors for input and output
  1178  	memUsage += inputSize * 4 * 2
  1179  	// hash vector
  1180  	// 8 byte for uint64 hash value
  1181  	// 2 vectors for input and output
  1182  	memUsage += inputSize * 8 * 2
  1183  	// we sort dim index values as value, and hash value as key
  1184  	memUsage += inputSize * (8 + 4)
  1185  	return
  1186  }
  1187  
  1188  // estimateScratchSpaceMemUsage calculates memory usage for an expression
  1189  func estimateScratchSpaceMemUsage(exp expr.Expr, firstColumnSize int, isRoot bool) (int, int) {
  1190  	var currentMemUsage int
  1191  	var maxMemUsage int
  1192  
  1193  	switch e := exp.(type) {
  1194  	case *expr.ParenExpr:
  1195  		return estimateScratchSpaceMemUsage(e.Expr, firstColumnSize, isRoot)
  1196  	case *expr.UnaryExpr:
  1197  		childCurrentMemUsage, childMaxMemUsage := estimateScratchSpaceMemUsage(e.Expr, firstColumnSize, false)
  1198  		if !isRoot {
  1199  			currentMemUsage = firstColumnSize * 5
  1200  		}
  1201  		maxMemUsage = int(math.Max(float64(childCurrentMemUsage+currentMemUsage), float64(childMaxMemUsage)))
  1202  		return currentMemUsage, maxMemUsage
  1203  	case *expr.BinaryExpr:
  1204  		lhsCurrentMemUsage, lhsMaxMemUsage := estimateScratchSpaceMemUsage(e.LHS, firstColumnSize, false)
  1205  		rhsCurrentMemUsage, rhsMaxMemUsage := estimateScratchSpaceMemUsage(e.RHS, firstColumnSize, false)
  1206  
  1207  		if !isRoot {
  1208  			currentMemUsage = firstColumnSize * 5
  1209  		}
  1210  
  1211  		childrenMaxMemUsage := math.Max(float64(lhsMaxMemUsage), float64(rhsMaxMemUsage))
  1212  		maxMemUsage = int(math.Max(float64(currentMemUsage+lhsCurrentMemUsage+rhsCurrentMemUsage), float64(childrenMaxMemUsage)))
  1213  
  1214  		return currentMemUsage, maxMemUsage
  1215  	default:
  1216  		return 0, 0
  1217  	}
  1218  }
  1219  
  1220  // calculateForeignTableMemUsage returns how much device memory is needed for foreign table
  1221  func (qc *AQLQueryContext) calculateForeignTableMemUsage(memStore memstore.MemStore) int {
  1222  	var memUsage int
  1223  
  1224  	for joinTableID, join := range qc.Query.Joins {
  1225  		// join only support dimension table for now
  1226  		// and dimension table is not shared
  1227  		shard, err := memStore.GetTableShard(join.Table, 0)
  1228  		if err != nil {
  1229  			qc.Error = utils.StackError(err, "Failed to get shard for table %s, shard: %d", join.Table, 0)
  1230  			return 0
  1231  		}
  1232  
  1233  		// only need live store for dimension table
  1234  		batchIDs, _ := shard.LiveStore.GetBatchIDs()
  1235  
  1236  		// primary key
  1237  		memUsage += int(shard.LiveStore.PrimaryKey.AllocatedBytes())
  1238  
  1239  		// VPs
  1240  		for _, batchID := range batchIDs {
  1241  			batch := shard.LiveStore.GetBatchForRead(batchID)
  1242  			if batch == nil {
  1243  				continue
  1244  			}
  1245  
  1246  			for _, columnID := range qc.TableScanners[joinTableID+1].Columns {
  1247  				usage := qc.TableScanners[joinTableID+1].ColumnUsages[columnID]
  1248  				if usage&(columnUsedByAllBatches|columnUsedByLiveBatches) != 0 {
  1249  					sourceVP := batch.Columns[columnID]
  1250  					if sourceVP == nil {
  1251  						continue
  1252  					}
  1253  					memUsage += int(sourceVP.GetBytes())
  1254  				}
  1255  			}
  1256  			batch.RUnlock()
  1257  		}
  1258  		shard.Users.Done()
  1259  	}
  1260  
  1261  	return memUsage
  1262  }
  1263  
  1264  // FindDeviceForQuery calls device manager to find a device for the query
  1265  func (qc *AQLQueryContext) FindDeviceForQuery(memStore memstore.MemStore, preferredDevice int,
  1266  	deviceManager *DeviceManager, timeout int) {
  1267  	memoryRequired := qc.calculateMemoryRequirement(memStore)
  1268  	if qc.Error != nil {
  1269  		return
  1270  	}
  1271  
  1272  	qc.OOPK.DeviceMemoryRequirement = memoryRequired
  1273  
  1274  	waitStart := utils.Now()
  1275  	device := deviceManager.FindDevice(qc.Query, memoryRequired, preferredDevice, timeout)
  1276  	if device == -1 {
  1277  		qc.Error = utils.StackError(nil, "Unable to find device to run this query")
  1278  	}
  1279  	qc.OOPK.DurationWaitedForDevice = utils.Now().Sub(waitStart)
  1280  	qc.Device = device
  1281  }
  1282  
  1283  func (qc *AQLQueryContext) runBatchExecutor(e BatchExecutor, isLastBatch bool) {
  1284  	start := utils.Now()
  1285  	e.preExec(isLastBatch, start)
  1286  
  1287  	e.filter()
  1288  
  1289  	e.join()
  1290  
  1291  	e.project()
  1292  
  1293  	e.reduce()
  1294  
  1295  	e.postExec(start)
  1296  }
  1297  
  1298  // copyHostToDevice copy vector party slice to device vector party slice
  1299  func copyHostToDevice(vps memCom.HostVectorPartySlice, deviceVPSlice deviceVectorPartySlice, stream unsafe.Pointer, device int) (bytesCopied, numTransfers int) {
  1300  	if vps.ValueBytes > 0 {
  1301  		memutils.AsyncCopyHostToDevice(
  1302  			deviceVPSlice.values.getPointer(), vps.Values, vps.ValueBytes,
  1303  			stream, device)
  1304  		bytesCopied += vps.ValueBytes
  1305  		numTransfers++
  1306  	}
  1307  	if vps.NullBytes > 0 {
  1308  		memutils.AsyncCopyHostToDevice(
  1309  			deviceVPSlice.nulls.getPointer(), vps.Nulls, vps.NullBytes,
  1310  			stream, device)
  1311  		bytesCopied += vps.NullBytes
  1312  		numTransfers++
  1313  	}
  1314  	if vps.CountBytes > 0 {
  1315  		memutils.AsyncCopyHostToDevice(
  1316  			deviceVPSlice.counts.getPointer(), vps.Counts, vps.CountBytes,
  1317  			stream, device)
  1318  		bytesCopied += vps.CountBytes
  1319  		numTransfers++
  1320  	}
  1321  	return
  1322  }
  1323  
  1324  func hostToDeviceColumn(hostColumn memCom.HostVectorPartySlice, device int) deviceVectorPartySlice {
  1325  	deviceColumn := deviceVectorPartySlice{
  1326  		length:          hostColumn.Length,
  1327  		valueType:       hostColumn.ValueType,
  1328  		defaultValue:    hostColumn.DefaultValue,
  1329  		valueStartIndex: hostColumn.ValueStartIndex,
  1330  		nullStartIndex:  hostColumn.NullStartIndex,
  1331  		countStartIndex: hostColumn.CountStartIndex,
  1332  	}
  1333  	totalColumnBytes := hostColumn.ValueBytes + hostColumn.NullBytes + hostColumn.CountBytes
  1334  
  1335  	if totalColumnBytes > 0 {
  1336  		deviceColumn.basePtr = deviceAllocate(totalColumnBytes, device)
  1337  		if hostColumn.Counts != nil {
  1338  			deviceColumn.counts = deviceColumn.basePtr.offset(0)
  1339  		}
  1340  
  1341  		if hostColumn.Nulls != nil {
  1342  			deviceColumn.nulls = deviceColumn.basePtr.offset(hostColumn.CountBytes)
  1343  		}
  1344  
  1345  		deviceColumn.values = deviceColumn.basePtr.offset(
  1346  			hostColumn.NullBytes + hostColumn.CountBytes)
  1347  	}
  1348  	return deviceColumn
  1349  }
  1350  
  1351  // shouldSkipLiveBatch will determine whether we can skip processing a live batch by checking time filter and
  1352  // eligible main table common filters. The batch must be non nil.
  1353  func (qc *AQLQueryContext) shouldSkipLiveBatch(b *memstore.LiveBatch) bool {
  1354  	candidatesFilters := []expr.Expr{qc.OOPK.TimeFilters[0], qc.OOPK.TimeFilters[1]}
  1355  	candidatesFilters = append(candidatesFilters, qc.OOPK.MainTableCommonFilters...)
  1356  	for _, filter := range candidatesFilters {
  1357  		if shouldSkipLiveBatchWithFilter(b, filter) {
  1358  			return true
  1359  		}
  1360  	}
  1361  	return false
  1362  }
  1363  
  1364  // shouldSkipLiveBatchWithFilter will check max and min for the corresponding column against the filter express and
  1365  // determines whether we should skip processing this live batch.
  1366  // Following constraints apply:
  1367  //  1. Filter must be on main table.
  1368  //  2. Filter must be a binary expression.
  1369  //  3. OPs must be one of (EQ, GTE,GE,LTE,LE).
  1370  //  4. One side of the expr must be VarRef
  1371  //  5. Another side of the xpr must be NumericalLiteral
  1372  //  6. ColumnType must be UInt32
  1373  func shouldSkipLiveBatchWithFilter(b *memstore.LiveBatch, filter expr.Expr) bool {
  1374  	if filter == nil {
  1375  		return false
  1376  	}
  1377  
  1378  	if binExpr, ok := filter.(*expr.BinaryExpr); ok {
  1379  		var columnExpr *expr.VarRef
  1380  		var numExpr *expr.NumberLiteral
  1381  		op := binExpr.Op
  1382  		switch op {
  1383  		case expr.GTE, expr.GT, expr.LT, expr.LTE, expr.EQ:
  1384  		default:
  1385  			return false
  1386  		}
  1387  		// First try lhs VarRef, rhs Num.
  1388  		lhsVarRef, lhsOK := binExpr.LHS.(*expr.VarRef)
  1389  		rhsNum, rhsOK := binExpr.RHS.(*expr.NumberLiteral)
  1390  		if lhsOK && rhsOK {
  1391  			columnExpr = lhsVarRef
  1392  			numExpr = rhsNum
  1393  		} else {
  1394  			// Then try rhs VarRef, lhs Num.
  1395  			lhsNum, lhsOK := binExpr.LHS.(*expr.NumberLiteral)
  1396  			rhsVarRef, rhsOK := binExpr.RHS.(*expr.VarRef)
  1397  			if lhsOK && rhsOK {
  1398  				// Swap column to the left and number to right.
  1399  				columnExpr = rhsVarRef
  1400  				numExpr = lhsNum
  1401  				// Invert the OP.
  1402  				switch op {
  1403  				case expr.GTE:
  1404  					op = expr.LTE
  1405  				case expr.GT:
  1406  					op = expr.LT
  1407  				case expr.LTE:
  1408  					op = expr.GTE
  1409  				case expr.LT:
  1410  					op = expr.GT
  1411  				}
  1412  			}
  1413  		}
  1414  
  1415  		if columnExpr != nil && numExpr != nil {
  1416  			// Time filters and main table filters are guaranteed to be on main table.
  1417  			vp := b.Columns[columnExpr.ColumnID]
  1418  			if vp == nil {
  1419  				return true
  1420  			}
  1421  
  1422  			if columnExpr.DataType != memCom.Uint32 {
  1423  				return false
  1424  			}
  1425  
  1426  			num := int64(numExpr.Int)
  1427  			minUint32, maxUint32 := vp.(memCom.LiveVectorParty).GetMinMaxValue()
  1428  			min, max := int64(minUint32), int64(maxUint32)
  1429  			switch op {
  1430  			case expr.GTE:
  1431  				return max < num
  1432  			case expr.GT:
  1433  				return max <= num
  1434  			case expr.LTE:
  1435  				return min > num
  1436  			case expr.LT:
  1437  				return min >= num
  1438  			case expr.EQ:
  1439  				return min > num || max < num
  1440  			}
  1441  		}
  1442  	}
  1443  	return false
  1444  }