github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/colexec/hash_aggregator.go (about) 1 // Copyright 2019 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package colexec 12 13 import ( 14 "context" 15 "unsafe" 16 17 "github.com/cockroachdb/cockroach/pkg/col/coldata" 18 "github.com/cockroachdb/cockroach/pkg/col/typeconv" 19 "github.com/cockroachdb/cockroach/pkg/sql/colexecbase" 20 "github.com/cockroachdb/cockroach/pkg/sql/colexecbase/colexecerror" 21 "github.com/cockroachdb/cockroach/pkg/sql/colmem" 22 "github.com/cockroachdb/cockroach/pkg/sql/execinfrapb" 23 "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" 24 "github.com/cockroachdb/cockroach/pkg/sql/types" 25 "github.com/cockroachdb/errors" 26 ) 27 28 // hashAggregatorState represents the state of the hash aggregator operator. 29 type hashAggregatorState int 30 31 const ( 32 // hashAggregatorAggregating is the state in which the hashAggregator is 33 // reading the batches from the input and performing aggregation on them, 34 // one at a time. After the input has been fully exhausted, hashAggregator 35 // transitions to hashAggregatorOutputting state. 36 hashAggregatorAggregating hashAggregatorState = iota 37 38 // hashAggregatorOutputting is the state in which the hashAggregator is 39 // writing its aggregation results to output buffer after. 40 hashAggregatorOutputting 41 42 // hashAggregatorDone is the state in which the hashAggregator has finished 43 // writing to the output buffer. 44 hashAggregatorDone 45 ) 46 47 // hashAggregator is an operator that performs aggregation based on specified 48 // grouping columns. This operator performs aggregation in online fashion. It 49 // reads the input one batch at a time, hashes each tuple from the batch and 50 // groups the tuples with same hash code into same group. Then aggregation 51 // function is lazily created for each group. The tuples in that group will be 52 // then passed into the aggregation function. After the input is exhausted, the 53 // operator begins to write the result into an output buffer. The output row 54 // ordering of this operator is arbitrary. 55 type hashAggregator struct { 56 OneInputNode 57 58 allocator *colmem.Allocator 59 60 aggCols [][]uint32 61 aggTypes [][]*types.T 62 aggFuncs []execinfrapb.AggregatorSpec_Func 63 64 inputTypes []*types.T 65 outputTypes []*types.T 66 67 // aggFuncMap stores the mapping from hash code to a vector of aggregation 68 // functions. Each aggregation function is stored along with keys that 69 // corresponds to the group the aggregation function operates on. This is to 70 // handle hash collisions. 71 aggFuncMap hashAggFuncMap 72 73 // state stores the current state of hashAggregator. 74 state hashAggregatorState 75 76 scratch struct { 77 // sels stores the intermediate selection vector for each hash code. It 78 // is maintained in such a way that when for a particular hashCode 79 // there are no tuples in the batch, the corresponding int slice is of 80 // length 0. Also, onlineAgg() method will reset all modified slices to 81 // have zero length once it is done processing all tuples in the batch, 82 // this allows us to not reset the slices for all possible hash codes. 83 // 84 // Instead of having a map from hashCode to []int (which could result 85 // in having many int slices), we are using a constant number of such 86 // slices and have a "map" from hashCode to a "slot" in sels that does 87 // the "translation." The key insight here is that we will have at most 88 // coldata.BatchSize() different hashCodes at once. 89 sels [][]int 90 // hashCodeForSelsSlot stores the hashCode that corresponds to a slot 91 // in sels slice. For example, if we have tuples with the following 92 // hashCodes = {0, 2, 0, 0, 1, 2, 1}, then we will have: 93 // hashCodeForSelsSlot = {0, 2, 1} 94 // sels[0] = {0, 2, 3} 95 // sels[1] = {1, 5} 96 // sels[2] = {4, 6} 97 // Note that we're not using Golang's map for this purpose because 98 // although, in theory, it has O(1) amortized lookup cost, in practice, 99 // it is faster to do a linear search for a particular hashCode in this 100 // slice: given that we have at most coldata.BatchSize() number of 101 // different hashCodes - which is a constant - we get 102 // O(coldata.BatchSize()) = O(1) lookup cost. And now we have two cases 103 // to consider: 104 // 1. we have few distinct hashCodes (large group sizes), then the 105 // overhead of linear search will be significantly lower than of a 106 // lookup in map 107 // 2. we have many distinct hashCodes (small group sizes), then the 108 // map *might* outperform the linear search, but the time spent in 109 // other parts of the hash aggregator will dominate the total runtime, 110 // so this would not matter. 111 hashCodeForSelsSlot []uint64 112 113 // group is a boolean vector where "true" represent the beginning of a group 114 // in the column. It is shared among all aggregation functions. Since 115 // hashAggregator manually manages mapping between input groups and their 116 // corresponding aggregation functions, group is set to all false to prevent 117 // premature materialization of aggregation result in the aggregation 118 // function. However, aggregation function expects at least one group in its 119 // input batches, (that is, at least one "true" in the group vector 120 // corresponding to the selection vector). Therefore, before the first 121 // invocation of .Compute() method, the element in group vector which 122 // corresponds to the first value of the selection vector is set to true so 123 // that aggregation function will initialize properly. Then after .Compute() 124 // finishes, it is set back to false so the same group vector can be reused 125 // by other aggregation functions. 126 group []bool 127 } 128 129 // keyMapping stores the key values for each aggregation group. It is a 130 // bufferedBatch because in the worst case where all keys in the grouping 131 // columns are distinct, we need to store every single key in the input. 132 keyMapping *appendOnlyBufferedBatch 133 134 output struct { 135 coldata.Batch 136 137 // pendingOutput indicates if there is more data that needs to be returned. 138 pendingOutput bool 139 140 // resumeHashCode is the hash code that hashAggregator should start reading 141 // from on the next iteration of Next(). 142 resumeHashCode uint64 143 144 // resumeIdx is the index of the vector corresponding to the resumeHashCode 145 // that hashAggregator should start reading from on the next iteration of Next(). 146 resumeIdx int 147 } 148 149 testingKnobs struct { 150 // numOfHashBuckets is the number of hash buckets that each tuple will be 151 // assigned to. When it is 0, hash aggregator will not enforce maximum 152 // number of hash buckets. It is used to test hash collision. 153 numOfHashBuckets uint64 154 } 155 156 // groupCols stores the indices of the grouping columns. 157 groupCols []uint32 158 159 // groupTypes stores the types of the grouping columns. 160 groupTypes []*types.T 161 groupCanonicalTypeFamilies []types.Family 162 163 // hashBuffer stores hash values for each tuple in the buffered batch. 164 hashBuffer []uint64 165 166 aggFnsAlloc *aggregateFuncsAlloc 167 hashAlloc hashAggFuncsAlloc 168 cancelChecker CancelChecker 169 overloadHelper overloadHelper 170 datumAlloc sqlbase.DatumAlloc 171 } 172 173 var _ colexecbase.Operator = &hashAggregator{} 174 175 // hashAggregatorAllocSize determines the allocation size used by the hash 176 // aggregator's allocators. This number was chosen after running benchmarks of 177 // 'sum' aggregation on ints and decimals with varying group sizes (powers of 2 178 // from 1 to 4096). 179 const hashAggregatorAllocSize = 64 180 181 // NewHashAggregator creates a hash aggregator on the given grouping columns. 182 // The input specifications to this function are the same as that of the 183 // NewOrderedAggregator function. 184 func NewHashAggregator( 185 allocator *colmem.Allocator, 186 input colexecbase.Operator, 187 typs []*types.T, 188 aggFns []execinfrapb.AggregatorSpec_Func, 189 groupCols []uint32, 190 aggCols [][]uint32, 191 ) (colexecbase.Operator, error) { 192 aggTyps := extractAggTypes(aggCols, typs) 193 outputTypes, err := makeAggregateFuncsOutputTypes(aggTyps, aggFns) 194 if err != nil { 195 return nil, errors.AssertionFailedf( 196 "this error should have been checked in isAggregateSupported\n%+v", err, 197 ) 198 } 199 200 groupTypes := make([]*types.T, len(groupCols)) 201 for i, colIdx := range groupCols { 202 groupTypes[i] = typs[colIdx] 203 } 204 205 aggFnsAlloc, err := newAggregateFuncsAlloc(allocator, aggTyps, aggFns, hashAggregatorAllocSize) 206 207 return &hashAggregator{ 208 OneInputNode: NewOneInputNode(input), 209 allocator: allocator, 210 211 aggCols: aggCols, 212 aggFuncs: aggFns, 213 aggTypes: aggTyps, 214 aggFuncMap: make(hashAggFuncMap), 215 216 state: hashAggregatorAggregating, 217 inputTypes: typs, 218 outputTypes: outputTypes, 219 220 groupCols: groupCols, 221 groupTypes: groupTypes, 222 groupCanonicalTypeFamilies: typeconv.ToCanonicalTypeFamilies(groupTypes), 223 224 aggFnsAlloc: aggFnsAlloc, 225 hashAlloc: hashAggFuncsAlloc{allocator: allocator}, 226 }, err 227 } 228 229 func (op *hashAggregator) Init() { 230 op.input.Init() 231 op.output.Batch = op.allocator.NewMemBatch(op.outputTypes) 232 233 op.scratch.sels = make([][]int, coldata.BatchSize()) 234 op.scratch.hashCodeForSelsSlot = make([]uint64, coldata.BatchSize()) 235 op.scratch.group = make([]bool, coldata.BatchSize()) 236 // Eventually, op.keyMapping will contain as many tuples as there are 237 // groups in the input, but we don't know that number upfront, so we 238 // allocate it with some reasonably sized constant capacity. 239 op.keyMapping = newAppendOnlyBufferedBatch( 240 op.allocator, op.groupTypes, coldata.BatchSize(), 241 ) 242 op.hashBuffer = make([]uint64, coldata.BatchSize()) 243 } 244 245 func (op *hashAggregator) Next(ctx context.Context) coldata.Batch { 246 for { 247 switch op.state { 248 case hashAggregatorAggregating: 249 b := op.input.Next(ctx) 250 if b.Length() == 0 { 251 op.state = hashAggregatorOutputting 252 continue 253 } 254 op.buildSelectionForEachHashCode(ctx, b) 255 op.onlineAgg(b) 256 case hashAggregatorOutputting: 257 curOutputIdx := 0 258 op.output.ResetInternalBatch() 259 260 // If there is pending output, we try to finish outputting the aggregation 261 // result in the same bucket. If we cannot finish, we update resumeIdx and 262 // return the current batch. 263 if op.output.pendingOutput { 264 remainingAggFuncs := op.aggFuncMap[op.output.resumeHashCode][op.output.resumeIdx:] 265 for groupIdx, aggFunc := range remainingAggFuncs { 266 if curOutputIdx < coldata.BatchSize() { 267 for _, fn := range aggFunc.fns { 268 fn.SetOutputIndex(curOutputIdx) 269 fn.Flush() 270 } 271 } else { 272 op.output.resumeIdx = op.output.resumeIdx + groupIdx 273 op.output.SetLength(curOutputIdx) 274 275 return op.output 276 } 277 curOutputIdx++ 278 } 279 delete(op.aggFuncMap, op.output.resumeHashCode) 280 } 281 282 op.output.pendingOutput = false 283 284 for aggHashCode, aggFuncs := range op.aggFuncMap { 285 for groupIdx, aggFunc := range aggFuncs { 286 if curOutputIdx < coldata.BatchSize() { 287 for _, fn := range aggFunc.fns { 288 fn.SetOutputIndex(curOutputIdx) 289 fn.Flush() 290 } 291 } else { 292 // If current batch is filled, we record where we left off 293 // and then return the current batch. 294 op.output.resumeIdx = groupIdx 295 op.output.resumeHashCode = aggHashCode 296 op.output.pendingOutput = true 297 op.output.SetLength(curOutputIdx) 298 299 return op.output 300 } 301 curOutputIdx++ 302 } 303 delete(op.aggFuncMap, aggHashCode) 304 } 305 306 op.state = hashAggregatorDone 307 op.output.SetLength(curOutputIdx) 308 return op.output 309 case hashAggregatorDone: 310 return coldata.ZeroBatch 311 default: 312 colexecerror.InternalError("hash aggregator in unhandled state") 313 // This code is unreachable, but the compiler cannot infer that. 314 return nil 315 } 316 } 317 } 318 319 func (op *hashAggregator) buildSelectionForEachHashCode(ctx context.Context, b coldata.Batch) { 320 nKeys := b.Length() 321 hashBuffer := op.hashBuffer[:nKeys] 322 323 initHash(hashBuffer, nKeys, defaultInitHashValue) 324 325 for _, colIdx := range op.groupCols { 326 rehash(ctx, 327 hashBuffer, 328 b.ColVec(int(colIdx)), 329 nKeys, 330 b.Selection(), 331 op.cancelChecker, 332 op.overloadHelper, 333 &op.datumAlloc, 334 ) 335 } 336 337 if op.testingKnobs.numOfHashBuckets != 0 { 338 finalizeHash(hashBuffer, nKeys, op.testingKnobs.numOfHashBuckets) 339 } 340 341 op.populateSels(b, hashBuffer) 342 } 343 344 // onlineAgg probes aggFuncMap using the built sels map and lazily creates 345 // aggFunctions for each group if it doesn't not exist. Then it calls Compute() 346 // on each aggregation function to perform aggregation. 347 func (op *hashAggregator) onlineAgg(b coldata.Batch) { 348 for selsSlot, hashCode := range op.scratch.hashCodeForSelsSlot { 349 remaining := op.scratch.sels[selsSlot] 350 351 var anyMatched bool 352 353 // Stage 1: Probe aggregate functions for each hash code and perform 354 // aggregation. 355 if aggFuncs, ok := op.aggFuncMap[hashCode]; ok { 356 for _, aggFunc := range aggFuncs { 357 // We write the selection vector of matched tuples directly 358 // into the selection vector of b and selection vector of 359 // unmatched tuples into 'remaining'.'remaining' will reuse the 360 // underlying memory allocated for 'sel' to avoid extra 361 // allocation and copying. 362 anyMatched, remaining = aggFunc.match( 363 remaining, b, op.groupCols, op.groupTypes, 364 op.groupCanonicalTypeFamilies, op.keyMapping, 365 op.scratch.group[:len(remaining)], false, /* firstDefiniteMatch */ 366 ) 367 if anyMatched { 368 aggFunc.compute(b, op.aggCols) 369 } 370 } 371 } else { 372 // No aggregate functions exist for this hashCode, create one. 373 op.aggFuncMap[hashCode] = op.hashAlloc.newHashAggFuncsSlice() 374 } 375 376 // Stage 2: Build aggregate function that doesn't exist, then perform 377 // aggregation on the newly created aggregate function. 378 for len(remaining) > 0 { 379 // Record the selection vector index of the beginning of the group. 380 groupStartIdx := remaining[0] 381 382 // Build new agg functions. 383 keyIdx := op.keyMapping.Length() 384 aggFunc := op.hashAlloc.newHashAggFuncs() 385 aggFunc.keyIdx = keyIdx 386 387 // Store the key of the current aggregating group into keyMapping. 388 op.allocator.PerformOperation(op.keyMapping.ColVecs(), func() { 389 for keyIdx, colIdx := range op.groupCols { 390 // TODO(azhng): Try to preallocate enough memory so instead of 391 // .Append() we can use execgen.SET to improve the 392 // performance. 393 op.keyMapping.ColVec(keyIdx).Append(coldata.SliceArgs{ 394 Src: b.ColVec(int(colIdx)), 395 DestIdx: aggFunc.keyIdx, 396 SrcStartIdx: groupStartIdx, 397 SrcEndIdx: groupStartIdx + 1, 398 }) 399 } 400 op.keyMapping.SetLength(keyIdx + 1) 401 }) 402 403 aggFunc.fns = op.aggFnsAlloc.makeAggregateFuncs() 404 op.aggFuncMap[hashCode] = append(op.aggFuncMap[hashCode], aggFunc) 405 406 // Select rest of the tuples that matches the current key. We don't need 407 // to check if there is any match since 'remaining[0]' will always be 408 // matched. 409 _, remaining = aggFunc.match( 410 remaining, b, op.groupCols, op.groupTypes, 411 op.groupCanonicalTypeFamilies, op.keyMapping, 412 op.scratch.group[:len(remaining)], true, /* firstDefiniteMatch */ 413 ) 414 415 // Hack required to get aggregation function working. See '.scratch.group' 416 // field comment in hashAggregator for more details. 417 op.scratch.group[groupStartIdx] = true 418 aggFunc.init(op.scratch.group, op.output.Batch) 419 aggFunc.compute(b, op.aggCols) 420 op.scratch.group[groupStartIdx] = false 421 } 422 423 // We have processed all tuples with this hashCode, so we should reset 424 // the length of the corresponding slice. 425 op.scratch.sels[selsSlot] = op.scratch.sels[selsSlot][:0] 426 } 427 } 428 429 // reset resets the hashAggregator for another run. Primarily used for 430 // benchmarks. 431 func (op *hashAggregator) reset(ctx context.Context) { 432 if r, ok := op.input.(resetter); ok { 433 r.reset(ctx) 434 } 435 436 op.aggFuncMap = hashAggFuncMap{} 437 op.state = hashAggregatorAggregating 438 439 op.output.ResetInternalBatch() 440 op.output.SetLength(0) 441 op.output.pendingOutput = false 442 443 op.keyMapping.ResetInternalBatch() 444 op.keyMapping.SetLength(0) 445 } 446 447 // hashAggFuncs stores the aggregation functions for the corresponding 448 // aggregating group. 449 type hashAggFuncs struct { 450 // keyIdx is the index of key of the current aggregating group, which is 451 // stored in the hashAggregator keyMapping batch. 452 keyIdx int 453 454 fns []aggregateFunc 455 } 456 457 const ( 458 sizeOfHashAggFuncs = unsafe.Sizeof(hashAggFuncs{}) 459 sizeOfHashAggFuncsPtr = unsafe.Sizeof(&hashAggFuncs{}) 460 ) 461 462 // TODO(yuzefovich): we need to account for memory used by this map. It is 463 // likely that we will replace Golang's map with our vectorized hash table, so 464 // we might hold off with fixing the accounting until then. 465 type hashAggFuncMap map[uint64][]*hashAggFuncs 466 467 func (v *hashAggFuncs) init(group []bool, b coldata.Batch) { 468 for fnIdx, fn := range v.fns { 469 fn.Init(group, b.ColVec(fnIdx)) 470 } 471 } 472 473 func (v *hashAggFuncs) compute(b coldata.Batch, aggCols [][]uint32) { 474 for fnIdx, fn := range v.fns { 475 fn.Compute(b, aggCols[fnIdx]) 476 } 477 } 478 479 // hashAggFuncsAlloc is a utility struct that batches allocations of 480 // hashAggFuncs and slices of pointers to hashAggFuncs. 481 type hashAggFuncsAlloc struct { 482 allocator *colmem.Allocator 483 buf []hashAggFuncs 484 ptrBuf []*hashAggFuncs 485 } 486 487 func (a *hashAggFuncsAlloc) newHashAggFuncs() *hashAggFuncs { 488 if len(a.buf) == 0 { 489 a.allocator.AdjustMemoryUsage(int64(hashAggregatorAllocSize * sizeOfHashAggFuncs)) 490 a.buf = make([]hashAggFuncs, hashAggregatorAllocSize) 491 } 492 ret := &a.buf[0] 493 a.buf = a.buf[1:] 494 return ret 495 } 496 497 func (a *hashAggFuncsAlloc) newHashAggFuncsSlice() []*hashAggFuncs { 498 if len(a.ptrBuf) == 0 { 499 a.allocator.AdjustMemoryUsage(int64(hashAggregatorAllocSize * sizeOfHashAggFuncsPtr)) 500 a.ptrBuf = make([]*hashAggFuncs, hashAggregatorAllocSize) 501 } 502 // Since we don't expect a lot of hash collisions we only give out small 503 // amount of memory here. 504 ret := a.ptrBuf[0:0:1] 505 a.ptrBuf = a.ptrBuf[1:] 506 return ret 507 }