github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/colexec/hash_aggregator.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package colexec
    12  
    13  import (
    14  	"context"
    15  	"unsafe"
    16  
    17  	"github.com/cockroachdb/cockroach/pkg/col/coldata"
    18  	"github.com/cockroachdb/cockroach/pkg/col/typeconv"
    19  	"github.com/cockroachdb/cockroach/pkg/sql/colexecbase"
    20  	"github.com/cockroachdb/cockroach/pkg/sql/colexecbase/colexecerror"
    21  	"github.com/cockroachdb/cockroach/pkg/sql/colmem"
    22  	"github.com/cockroachdb/cockroach/pkg/sql/execinfrapb"
    23  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    24  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    25  	"github.com/cockroachdb/errors"
    26  )
    27  
    28  // hashAggregatorState represents the state of the hash aggregator operator.
    29  type hashAggregatorState int
    30  
    31  const (
    32  	// hashAggregatorAggregating is the state in which the hashAggregator is
    33  	// reading the batches from the input and performing aggregation on them,
    34  	// one at a time. After the input has been fully exhausted, hashAggregator
    35  	// transitions to hashAggregatorOutputting state.
    36  	hashAggregatorAggregating hashAggregatorState = iota
    37  
    38  	// hashAggregatorOutputting is the state in which the hashAggregator is
    39  	// writing its aggregation results to output buffer after.
    40  	hashAggregatorOutputting
    41  
    42  	// hashAggregatorDone is the state in which the hashAggregator has finished
    43  	// writing to the output buffer.
    44  	hashAggregatorDone
    45  )
    46  
// hashAggregator is an operator that performs aggregation based on specified
// grouping columns. This operator performs aggregation in online fashion. It
// reads the input one batch at a time, hashes the grouping columns of each
// tuple from the batch and buckets the tuples by hash code. Within a bucket,
// tuples are further split into groups by comparing their keys against the
// keys of the already existing groups (stored in keyMapping), which is how
// hash collisions are handled. The aggregation functions are lazily created
// for each group, and the tuples of that group are then passed into them.
// After the input is exhausted, the operator begins to write the results into
// an output batch. The output row ordering of this operator is arbitrary.
type hashAggregator struct {
	OneInputNode

	allocator *colmem.Allocator

	// aggCols[i] lists the input columns that aggFuncs[i] aggregates over,
	// and aggTypes[i] holds the corresponding types (computed by
	// extractAggTypes in NewHashAggregator).
	aggCols  [][]uint32
	aggTypes [][]*types.T
	aggFuncs []execinfrapb.AggregatorSpec_Func

	// inputTypes are the types of the input columns. outputTypes are the
	// types of the output batch, whose i-th column receives the results of
	// the i-th aggregate function.
	inputTypes  []*types.T
	outputTypes []*types.T

	// aggFuncMap stores the mapping from hash code to a slice of aggregation
	// function sets, one per group whose key has that hash code. Each set is
	// stored along with the index of its group's key in keyMapping, so that
	// groups sharing a hash code (hash collisions) can be told apart.
	aggFuncMap hashAggFuncMap

	// state stores the current state of hashAggregator.
	state hashAggregatorState

	scratch struct {
		// sels stores the intermediate selection vector for each hash code. It
		// is maintained in such a way that when for a particular hashCode
		// there are no tuples in the batch, the corresponding int slice is of
		// length 0. Also, onlineAgg() method will reset all modified slices to
		// have zero length once it is done processing all tuples in the batch,
		// this allows us to not reset the slices for all possible hash codes.
		//
		// Instead of having a map from hashCode to []int (which could result
		// in having many int slices), we are using a constant number of such
		// slices and have a "map" from hashCode to a "slot" in sels that does
		// the "translation." The key insight here is that we will have at most
		// coldata.BatchSize() different hashCodes at once.
		sels [][]int
		// hashCodeForSelsSlot stores the hashCode that corresponds to a slot
		// in sels slice. For example, if we have tuples with the following
		// hashCodes = {0, 2, 0, 0, 1, 2, 1}, then we will have:
		//   hashCodeForSelsSlot = {0, 2, 1}
		//   sels[0] = {0, 2, 3}
		//   sels[1] = {1, 5}
		//   sels[2] = {4, 6}
		// Note that we're not using Golang's map for this purpose because
		// although, in theory, it has O(1) amortized lookup cost, in practice,
		// it is faster to do a linear search for a particular hashCode in this
		// slice: given that we have at most coldata.BatchSize() number of
		// different hashCodes - which is a constant - we get
		// O(coldata.BatchSize()) = O(1) lookup cost. And now we have two cases
		// to consider:
		// 1. we have few distinct hashCodes (large group sizes), then the
		// overhead of linear search will be significantly lower than of a
		// lookup in map
		// 2. we have many distinct hashCodes (small group sizes), then the
		// map *might* outperform the linear search, but the time spent in
		// other parts of the hash aggregator will dominate the total runtime,
		// so this would not matter.
		hashCodeForSelsSlot []uint64

		// group is a boolean vector where "true" represents the beginning of a
		// group in the column. It is shared among all aggregation functions.
		// Since hashAggregator manually manages mapping between input groups
		// and their corresponding aggregation functions, group is set to all
		// false to prevent premature materialization of aggregation result in
		// the aggregation function. However, aggregation function expects at
		// least one group in its input batches (that is, at least one "true"
		// in the group vector corresponding to the selection vector).
		// Therefore, before the first invocation of .Compute() method, the
		// element in group vector which corresponds to the first value of the
		// selection vector is set to true so that aggregation function will
		// initialize properly. Then after .Compute() finishes, it is set back
		// to false so the same group vector can be reused by other aggregation
		// functions.
		group []bool
	}

	// keyMapping stores the key values for each aggregation group; a group's
	// hashAggFuncs.keyIdx is the row of its key in this batch. It is an
	// appendOnlyBufferedBatch because in the worst case where all keys in the
	// grouping columns are distinct, we need to store every single key in the
	// input.
	keyMapping *appendOnlyBufferedBatch

	output struct {
		// Batch is the output batch; its column vectors are handed to the
		// aggregate functions as their output vectors (see hashAggFuncs.init).
		coldata.Batch

		// pendingOutput indicates if there is more data that needs to be returned.
		pendingOutput bool

		// resumeHashCode is the hash code that hashAggregator should start reading
		// from on the next iteration of Next().
		resumeHashCode uint64

		// resumeIdx is the index into aggFuncMap[resumeHashCode] of the first
		// group that hashAggregator should emit on the next iteration of Next().
		resumeIdx int
	}

	testingKnobs struct {
		// numOfHashBuckets is the number of hash buckets that each tuple will be
		// assigned to. When it is 0, hash aggregator will not enforce maximum
		// number of hash buckets. It is used to test hash collisions.
		numOfHashBuckets uint64
	}

	// groupCols stores the indices of the grouping columns.
	groupCols []uint32

	// groupTypes stores the types of the grouping columns, and
	// groupCanonicalTypeFamilies stores their canonical type families (as
	// returned by typeconv.ToCanonicalTypeFamilies).
	groupTypes                 []*types.T
	groupCanonicalTypeFamilies []types.Family

	// hashBuffer stores the hash value of each tuple in the input batch that
	// is currently being processed.
	hashBuffer []uint64

	// aggFnsAlloc and hashAlloc batch the allocations of aggregate functions
	// and of hashAggFuncs, respectively. cancelChecker, overloadHelper and
	// datumAlloc are passed to rehash when hashing the grouping columns.
	aggFnsAlloc    *aggregateFuncsAlloc
	hashAlloc      hashAggFuncsAlloc
	cancelChecker  CancelChecker
	overloadHelper overloadHelper
	datumAlloc     sqlbase.DatumAlloc
}
   172  
   173  var _ colexecbase.Operator = &hashAggregator{}
   174  
   175  // hashAggregatorAllocSize determines the allocation size used by the hash
   176  // aggregator's allocators. This number was chosen after running benchmarks of
   177  // 'sum' aggregation on ints and decimals with varying group sizes (powers of 2
   178  // from 1 to 4096).
   179  const hashAggregatorAllocSize = 64
   180  
   181  // NewHashAggregator creates a hash aggregator on the given grouping columns.
   182  // The input specifications to this function are the same as that of the
   183  // NewOrderedAggregator function.
   184  func NewHashAggregator(
   185  	allocator *colmem.Allocator,
   186  	input colexecbase.Operator,
   187  	typs []*types.T,
   188  	aggFns []execinfrapb.AggregatorSpec_Func,
   189  	groupCols []uint32,
   190  	aggCols [][]uint32,
   191  ) (colexecbase.Operator, error) {
   192  	aggTyps := extractAggTypes(aggCols, typs)
   193  	outputTypes, err := makeAggregateFuncsOutputTypes(aggTyps, aggFns)
   194  	if err != nil {
   195  		return nil, errors.AssertionFailedf(
   196  			"this error should have been checked in isAggregateSupported\n%+v", err,
   197  		)
   198  	}
   199  
   200  	groupTypes := make([]*types.T, len(groupCols))
   201  	for i, colIdx := range groupCols {
   202  		groupTypes[i] = typs[colIdx]
   203  	}
   204  
   205  	aggFnsAlloc, err := newAggregateFuncsAlloc(allocator, aggTyps, aggFns, hashAggregatorAllocSize)
   206  
   207  	return &hashAggregator{
   208  		OneInputNode: NewOneInputNode(input),
   209  		allocator:    allocator,
   210  
   211  		aggCols:    aggCols,
   212  		aggFuncs:   aggFns,
   213  		aggTypes:   aggTyps,
   214  		aggFuncMap: make(hashAggFuncMap),
   215  
   216  		state:       hashAggregatorAggregating,
   217  		inputTypes:  typs,
   218  		outputTypes: outputTypes,
   219  
   220  		groupCols:                  groupCols,
   221  		groupTypes:                 groupTypes,
   222  		groupCanonicalTypeFamilies: typeconv.ToCanonicalTypeFamilies(groupTypes),
   223  
   224  		aggFnsAlloc: aggFnsAlloc,
   225  		hashAlloc:   hashAggFuncsAlloc{allocator: allocator},
   226  	}, err
   227  }
   228  
   229  func (op *hashAggregator) Init() {
   230  	op.input.Init()
   231  	op.output.Batch = op.allocator.NewMemBatch(op.outputTypes)
   232  
   233  	op.scratch.sels = make([][]int, coldata.BatchSize())
   234  	op.scratch.hashCodeForSelsSlot = make([]uint64, coldata.BatchSize())
   235  	op.scratch.group = make([]bool, coldata.BatchSize())
   236  	// Eventually, op.keyMapping will contain as many tuples as there are
   237  	// groups in the input, but we don't know that number upfront, so we
   238  	// allocate it with some reasonably sized constant capacity.
   239  	op.keyMapping = newAppendOnlyBufferedBatch(
   240  		op.allocator, op.groupTypes, coldata.BatchSize(),
   241  	)
   242  	op.hashBuffer = make([]uint64, coldata.BatchSize())
   243  }
   244  
   245  func (op *hashAggregator) Next(ctx context.Context) coldata.Batch {
   246  	for {
   247  		switch op.state {
   248  		case hashAggregatorAggregating:
   249  			b := op.input.Next(ctx)
   250  			if b.Length() == 0 {
   251  				op.state = hashAggregatorOutputting
   252  				continue
   253  			}
   254  			op.buildSelectionForEachHashCode(ctx, b)
   255  			op.onlineAgg(b)
   256  		case hashAggregatorOutputting:
   257  			curOutputIdx := 0
   258  			op.output.ResetInternalBatch()
   259  
   260  			// If there is pending output, we try to finish outputting the aggregation
   261  			// result in the same bucket. If we cannot finish, we update resumeIdx and
   262  			// return the current batch.
   263  			if op.output.pendingOutput {
   264  				remainingAggFuncs := op.aggFuncMap[op.output.resumeHashCode][op.output.resumeIdx:]
   265  				for groupIdx, aggFunc := range remainingAggFuncs {
   266  					if curOutputIdx < coldata.BatchSize() {
   267  						for _, fn := range aggFunc.fns {
   268  							fn.SetOutputIndex(curOutputIdx)
   269  							fn.Flush()
   270  						}
   271  					} else {
   272  						op.output.resumeIdx = op.output.resumeIdx + groupIdx
   273  						op.output.SetLength(curOutputIdx)
   274  
   275  						return op.output
   276  					}
   277  					curOutputIdx++
   278  				}
   279  				delete(op.aggFuncMap, op.output.resumeHashCode)
   280  			}
   281  
   282  			op.output.pendingOutput = false
   283  
   284  			for aggHashCode, aggFuncs := range op.aggFuncMap {
   285  				for groupIdx, aggFunc := range aggFuncs {
   286  					if curOutputIdx < coldata.BatchSize() {
   287  						for _, fn := range aggFunc.fns {
   288  							fn.SetOutputIndex(curOutputIdx)
   289  							fn.Flush()
   290  						}
   291  					} else {
   292  						// If current batch is filled, we record where we left off
   293  						// and then return the current batch.
   294  						op.output.resumeIdx = groupIdx
   295  						op.output.resumeHashCode = aggHashCode
   296  						op.output.pendingOutput = true
   297  						op.output.SetLength(curOutputIdx)
   298  
   299  						return op.output
   300  					}
   301  					curOutputIdx++
   302  				}
   303  				delete(op.aggFuncMap, aggHashCode)
   304  			}
   305  
   306  			op.state = hashAggregatorDone
   307  			op.output.SetLength(curOutputIdx)
   308  			return op.output
   309  		case hashAggregatorDone:
   310  			return coldata.ZeroBatch
   311  		default:
   312  			colexecerror.InternalError("hash aggregator in unhandled state")
   313  			// This code is unreachable, but the compiler cannot infer that.
   314  			return nil
   315  		}
   316  	}
   317  }
   318  
   319  func (op *hashAggregator) buildSelectionForEachHashCode(ctx context.Context, b coldata.Batch) {
   320  	nKeys := b.Length()
   321  	hashBuffer := op.hashBuffer[:nKeys]
   322  
   323  	initHash(hashBuffer, nKeys, defaultInitHashValue)
   324  
   325  	for _, colIdx := range op.groupCols {
   326  		rehash(ctx,
   327  			hashBuffer,
   328  			b.ColVec(int(colIdx)),
   329  			nKeys,
   330  			b.Selection(),
   331  			op.cancelChecker,
   332  			op.overloadHelper,
   333  			&op.datumAlloc,
   334  		)
   335  	}
   336  
   337  	if op.testingKnobs.numOfHashBuckets != 0 {
   338  		finalizeHash(hashBuffer, nKeys, op.testingKnobs.numOfHashBuckets)
   339  	}
   340  
   341  	op.populateSels(b, hashBuffer)
   342  }
   343  
   344  // onlineAgg probes aggFuncMap using the built sels map and lazily creates
   345  // aggFunctions for each group if it doesn't not exist. Then it calls Compute()
   346  // on each aggregation function to perform aggregation.
   347  func (op *hashAggregator) onlineAgg(b coldata.Batch) {
   348  	for selsSlot, hashCode := range op.scratch.hashCodeForSelsSlot {
   349  		remaining := op.scratch.sels[selsSlot]
   350  
   351  		var anyMatched bool
   352  
   353  		// Stage 1: Probe aggregate functions for each hash code and perform
   354  		//          aggregation.
   355  		if aggFuncs, ok := op.aggFuncMap[hashCode]; ok {
   356  			for _, aggFunc := range aggFuncs {
   357  				// We write the selection vector of matched tuples directly
   358  				// into the selection vector of b and selection vector of
   359  				// unmatched tuples into 'remaining'.'remaining' will reuse the
   360  				// underlying memory allocated for 'sel' to avoid extra
   361  				// allocation and copying.
   362  				anyMatched, remaining = aggFunc.match(
   363  					remaining, b, op.groupCols, op.groupTypes,
   364  					op.groupCanonicalTypeFamilies, op.keyMapping,
   365  					op.scratch.group[:len(remaining)], false, /* firstDefiniteMatch */
   366  				)
   367  				if anyMatched {
   368  					aggFunc.compute(b, op.aggCols)
   369  				}
   370  			}
   371  		} else {
   372  			// No aggregate functions exist for this hashCode, create one.
   373  			op.aggFuncMap[hashCode] = op.hashAlloc.newHashAggFuncsSlice()
   374  		}
   375  
   376  		// Stage 2: Build aggregate function that doesn't exist, then perform
   377  		//          aggregation on the newly created aggregate function.
   378  		for len(remaining) > 0 {
   379  			// Record the selection vector index of the beginning of the group.
   380  			groupStartIdx := remaining[0]
   381  
   382  			// Build new agg functions.
   383  			keyIdx := op.keyMapping.Length()
   384  			aggFunc := op.hashAlloc.newHashAggFuncs()
   385  			aggFunc.keyIdx = keyIdx
   386  
   387  			// Store the key of the current aggregating group into keyMapping.
   388  			op.allocator.PerformOperation(op.keyMapping.ColVecs(), func() {
   389  				for keyIdx, colIdx := range op.groupCols {
   390  					// TODO(azhng): Try to preallocate enough memory so instead of
   391  					// .Append() we can use execgen.SET to improve the
   392  					// performance.
   393  					op.keyMapping.ColVec(keyIdx).Append(coldata.SliceArgs{
   394  						Src:         b.ColVec(int(colIdx)),
   395  						DestIdx:     aggFunc.keyIdx,
   396  						SrcStartIdx: groupStartIdx,
   397  						SrcEndIdx:   groupStartIdx + 1,
   398  					})
   399  				}
   400  				op.keyMapping.SetLength(keyIdx + 1)
   401  			})
   402  
   403  			aggFunc.fns = op.aggFnsAlloc.makeAggregateFuncs()
   404  			op.aggFuncMap[hashCode] = append(op.aggFuncMap[hashCode], aggFunc)
   405  
   406  			// Select rest of the tuples that matches the current key. We don't need
   407  			// to check if there is any match since 'remaining[0]' will always be
   408  			// matched.
   409  			_, remaining = aggFunc.match(
   410  				remaining, b, op.groupCols, op.groupTypes,
   411  				op.groupCanonicalTypeFamilies, op.keyMapping,
   412  				op.scratch.group[:len(remaining)], true, /* firstDefiniteMatch */
   413  			)
   414  
   415  			// Hack required to get aggregation function working. See '.scratch.group'
   416  			// field comment in hashAggregator for more details.
   417  			op.scratch.group[groupStartIdx] = true
   418  			aggFunc.init(op.scratch.group, op.output.Batch)
   419  			aggFunc.compute(b, op.aggCols)
   420  			op.scratch.group[groupStartIdx] = false
   421  		}
   422  
   423  		// We have processed all tuples with this hashCode, so we should reset
   424  		// the length of the corresponding slice.
   425  		op.scratch.sels[selsSlot] = op.scratch.sels[selsSlot][:0]
   426  	}
   427  }
   428  
   429  // reset resets the hashAggregator for another run. Primarily used for
   430  // benchmarks.
   431  func (op *hashAggregator) reset(ctx context.Context) {
   432  	if r, ok := op.input.(resetter); ok {
   433  		r.reset(ctx)
   434  	}
   435  
   436  	op.aggFuncMap = hashAggFuncMap{}
   437  	op.state = hashAggregatorAggregating
   438  
   439  	op.output.ResetInternalBatch()
   440  	op.output.SetLength(0)
   441  	op.output.pendingOutput = false
   442  
   443  	op.keyMapping.ResetInternalBatch()
   444  	op.keyMapping.SetLength(0)
   445  }
   446  
   447  // hashAggFuncs stores the aggregation functions for the corresponding
   448  // aggregating group.
   449  type hashAggFuncs struct {
   450  	// keyIdx is the index of key of the current aggregating group, which is
   451  	// stored in the hashAggregator keyMapping batch.
   452  	keyIdx int
   453  
   454  	fns []aggregateFunc
   455  }
   456  
   457  const (
   458  	sizeOfHashAggFuncs    = unsafe.Sizeof(hashAggFuncs{})
   459  	sizeOfHashAggFuncsPtr = unsafe.Sizeof(&hashAggFuncs{})
   460  )
   461  
// hashAggFuncMap maps a hash code to the aggregating groups whose keys hash to
// it. Having more than one group under the same hash code means there was a
// hash collision; such groups are told apart by comparing keys stored in the
// hashAggregator's keyMapping (see hashAggFuncs.keyIdx).
//
// TODO(yuzefovich): we need to account for memory used by this map. It is
// likely that we will replace Golang's map with our vectorized hash table, so
// we might hold off with fixing the accounting until then.
type hashAggFuncMap map[uint64][]*hashAggFuncs
   466  
   467  func (v *hashAggFuncs) init(group []bool, b coldata.Batch) {
   468  	for fnIdx, fn := range v.fns {
   469  		fn.Init(group, b.ColVec(fnIdx))
   470  	}
   471  }
   472  
   473  func (v *hashAggFuncs) compute(b coldata.Batch, aggCols [][]uint32) {
   474  	for fnIdx, fn := range v.fns {
   475  		fn.Compute(b, aggCols[fnIdx])
   476  	}
   477  }
   478  
// hashAggFuncsAlloc is a utility struct that batches allocations of
// hashAggFuncs and slices of pointers to hashAggFuncs.
type hashAggFuncsAlloc struct {
	// allocator is notified (via AdjustMemoryUsage) of the memory used by
	// every newly allocated buf/ptrBuf chunk.
	allocator *colmem.Allocator
	// buf holds the not-yet-handed-out hashAggFuncs of the current chunk.
	buf []hashAggFuncs
	// ptrBuf holds the not-yet-handed-out pointer slots of the current chunk;
	// each slot backs one slice returned by newHashAggFuncsSlice.
	ptrBuf []*hashAggFuncs
}
   486  
   487  func (a *hashAggFuncsAlloc) newHashAggFuncs() *hashAggFuncs {
   488  	if len(a.buf) == 0 {
   489  		a.allocator.AdjustMemoryUsage(int64(hashAggregatorAllocSize * sizeOfHashAggFuncs))
   490  		a.buf = make([]hashAggFuncs, hashAggregatorAllocSize)
   491  	}
   492  	ret := &a.buf[0]
   493  	a.buf = a.buf[1:]
   494  	return ret
   495  }
   496  
   497  func (a *hashAggFuncsAlloc) newHashAggFuncsSlice() []*hashAggFuncs {
   498  	if len(a.ptrBuf) == 0 {
   499  		a.allocator.AdjustMemoryUsage(int64(hashAggregatorAllocSize * sizeOfHashAggFuncsPtr))
   500  		a.ptrBuf = make([]*hashAggFuncs, hashAggregatorAllocSize)
   501  	}
   502  	// Since we don't expect a lot of hash collisions we only give out small
   503  	// amount of memory here.
   504  	ret := a.ptrBuf[0:0:1]
   505  	a.ptrBuf = a.ptrBuf[1:]
   506  	return ret
   507  }