github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/colexec/ordered_aggregator.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package colexec
    12  
    13  import (
    14  	"context"
    15  
    16  	"github.com/cockroachdb/cockroach/pkg/col/coldata"
    17  	"github.com/cockroachdb/cockroach/pkg/sql/colexecbase"
    18  	"github.com/cockroachdb/cockroach/pkg/sql/colmem"
    19  	"github.com/cockroachdb/cockroach/pkg/sql/execinfrapb"
    20  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    21  	"github.com/cockroachdb/errors"
    22  )
    23  
    24  // orderedAggregator is an aggregator that performs arbitrary aggregations on
    25  // input ordered by a set of grouping columns. Before performing any
    26  // aggregations, the aggregator sets up a chain of distinct operators that will
    27  // produce a vector of booleans (referenced in groupCol) that specifies whether
    28  // or not the corresponding columns in the input batch are part of a new group.
    29  // The memory is modified by the distinct operator flow.
    30  // Every aggregate function will change the shape of the data. i.e. a new column
    31  // value will be output for each input group. Since the number of input groups
    32  // is variable and the number of output values is constant, care must be taken
    33  // not to overflow the output buffer. To avoid having to perform bounds checks
    34  // for the aggregate functions, the aggregator allocates twice the size of the
    35  // input batch for the functions to write to. Before the next batch is
    36  // processed, the aggregator checks what index the functions are outputting to.
    37  // If greater than the expected output batch size by downstream operators, the
    38  // overflow values are copied to the start of the batch. Since the input batch
    39  // size is not necessarily the same as the output batch size, more than one copy
    40  // and return must be performed until the aggregator is in a state where its
    41  // functions are in a state where the output indices would not overflow the
    42  // output batch if a worst case input batch is encountered (one where every
    43  // value is part of a new group).
    44  type orderedAggregator struct {
    45  	OneInputNode
    46  
    47  	allocator *colmem.Allocator
    48  	done      bool
    49  
    50  	aggCols  [][]uint32
    51  	aggTypes [][]*types.T
    52  
    53  	outputTypes []*types.T
    54  
    55  	// scratch is the Batch to output and variables related to it. Aggregate
    56  	// function operators write directly to this output batch.
    57  	scratch struct {
    58  		coldata.Batch
    59  		// shouldResetInternalBatch keeps track of whether the scratch.Batch should
    60  		// be reset. It is false in cases where we have overflow results still to
    61  		// return and therefore do not want to modify the batch.
    62  		shouldResetInternalBatch bool
    63  		// resumeIdx is the index at which the aggregation functions should start
    64  		// writing to on the next iteration of Next().
    65  		resumeIdx int
    66  		// inputSize and outputSize are 2*coldata.BatchSize() and
    67  		// coldata.BatchSize(), respectively, by default but can be other values
    68  		// for tests.
    69  		inputSize  int
    70  		outputSize int
    71  	}
    72  
    73  	// unsafeBatch is a coldata.Batch returned when only a subset of the
    74  	// scratch.Batch results is returned (i.e. work needs to be resumed on the
    75  	// next Next call). The values to return are copied into this batch to protect
    76  	// against downstream modification of the internal batch.
    77  	unsafeBatch coldata.Batch
    78  
    79  	// groupCol is the slice that aggregateFuncs use to determine whether a value
    80  	// is part of the current aggregation group. See aggregateFunc.Init for more
    81  	// information.
    82  	groupCol []bool
    83  	// aggregateFuncs are the aggregator's aggregate function operators.
    84  	aggregateFuncs []aggregateFunc
    85  	// isScalar indicates whether an aggregator is in scalar context.
    86  	isScalar bool
    87  	// seenNonEmptyBatch indicates whether a non-empty input batch has been
    88  	// observed.
    89  	seenNonEmptyBatch bool
    90  }
    91  
    92  var _ colexecbase.Operator = &orderedAggregator{}
    93  
    94  // NewOrderedAggregator creates an ordered aggregator on the given grouping
    95  // columns. aggCols is a slice where each index represents a new aggregation
    96  // function. The slice at that index specifies the columns of the input batch
    97  // that the aggregate function should work on.
    98  func NewOrderedAggregator(
    99  	allocator *colmem.Allocator,
   100  	input colexecbase.Operator,
   101  	typs []*types.T,
   102  	aggFns []execinfrapb.AggregatorSpec_Func,
   103  	groupCols []uint32,
   104  	aggCols [][]uint32,
   105  	isScalar bool,
   106  ) (colexecbase.Operator, error) {
   107  	if len(aggFns) != len(aggCols) {
   108  		return nil,
   109  			errors.Errorf(
   110  				"mismatched aggregation lengths: aggFns(%d), aggCols(%d)",
   111  				len(aggFns),
   112  				len(aggCols),
   113  			)
   114  	}
   115  
   116  	aggTypes := extractAggTypes(aggCols, typs)
   117  
   118  	op, groupCol, err := OrderedDistinctColsToOperators(input, groupCols, typs)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	a := &orderedAggregator{}
   124  	if len(groupCols) == 0 {
   125  		// If there were no groupCols, we can't rely on the distinct operators to
   126  		// mark the first row as distinct, so we have to do it ourselves. Set up a
   127  		// oneShotOp to set the first row to distinct.
   128  		op = &oneShotOp{
   129  			OneInputNode: NewOneInputNode(op),
   130  			fn: func(batch coldata.Batch) {
   131  				if batch.Length() == 0 {
   132  					return
   133  				}
   134  				if sel := batch.Selection(); sel != nil {
   135  					groupCol[sel[0]] = true
   136  				} else {
   137  					groupCol[0] = true
   138  				}
   139  			},
   140  			outputSourceRef: &a.input,
   141  		}
   142  	}
   143  
   144  	*a = orderedAggregator{
   145  		OneInputNode: NewOneInputNode(op),
   146  
   147  		allocator: allocator,
   148  		aggCols:   aggCols,
   149  		aggTypes:  aggTypes,
   150  		groupCol:  groupCol,
   151  		isScalar:  isScalar,
   152  	}
   153  
   154  	// We will be reusing the same aggregate functions, so we use 1 as the
   155  	// allocation size.
   156  	funcsAlloc, err := newAggregateFuncsAlloc(a.allocator, aggTypes, aggFns, 1 /* allocSize */)
   157  	if err != nil {
   158  		return nil, errors.AssertionFailedf(
   159  			"this error should have been checked in isAggregateSupported\n%+v", err,
   160  		)
   161  	}
   162  	a.aggregateFuncs = funcsAlloc.makeAggregateFuncs()
   163  	a.outputTypes, err = makeAggregateFuncsOutputTypes(aggTypes, aggFns)
   164  	if err != nil {
   165  		return nil, errors.AssertionFailedf(
   166  			"this error should have been checked in isAggregateSupported\n%+v", err,
   167  		)
   168  	}
   169  
   170  	return a, nil
   171  }
   172  
   173  func (a *orderedAggregator) initWithOutputBatchSize(outputSize int) {
   174  	a.initWithInputAndOutputBatchSize(coldata.BatchSize(), outputSize)
   175  }
   176  
   177  func (a *orderedAggregator) initWithInputAndOutputBatchSize(inputSize, outputSize int) {
   178  	a.input.Init()
   179  
   180  	// Twice the input batchSize is allocated to avoid having to check for
   181  	// overflow when outputting.
   182  	a.scratch.inputSize = inputSize * 2
   183  	a.scratch.outputSize = outputSize
   184  	a.scratch.Batch = a.allocator.NewMemBatchWithSize(a.outputTypes, a.scratch.inputSize)
   185  	for i := 0; i < len(a.outputTypes); i++ {
   186  		vec := a.scratch.ColVec(i)
   187  		a.aggregateFuncs[i].Init(a.groupCol, vec)
   188  	}
   189  	a.unsafeBatch = a.allocator.NewMemBatchWithSize(a.outputTypes, outputSize)
   190  }
   191  
   192  func (a *orderedAggregator) Init() {
   193  	a.initWithInputAndOutputBatchSize(coldata.BatchSize(), coldata.BatchSize())
   194  }
   195  
   196  func (a *orderedAggregator) Next(ctx context.Context) coldata.Batch {
   197  	a.unsafeBatch.ResetInternalBatch()
   198  	if a.scratch.shouldResetInternalBatch {
   199  		a.scratch.ResetInternalBatch()
   200  		a.scratch.shouldResetInternalBatch = false
   201  	}
   202  	if a.done {
   203  		a.scratch.SetLength(0)
   204  		return a.scratch
   205  	}
   206  	if a.scratch.resumeIdx >= a.scratch.outputSize {
   207  		// Copy the second part of the output batch into the first and resume from
   208  		// there.
   209  		newResumeIdx := a.scratch.resumeIdx - a.scratch.outputSize
   210  		a.allocator.PerformOperation(a.scratch.ColVecs(), func() {
   211  			for i := 0; i < len(a.outputTypes); i++ {
   212  				vec := a.scratch.ColVec(i)
   213  				// According to the aggregate function interface contract, the value at
   214  				// the current index must also be copied.
   215  				// Note that we're using Append here instead of Copy because we want the
   216  				// "truncation" behavior, i.e. we want to copy over the remaining tuples
   217  				// such the "lengths" of the vectors are equal to the number of copied
   218  				// elements.
   219  				vec.Append(
   220  					coldata.SliceArgs{
   221  						Src:         vec,
   222  						DestIdx:     0,
   223  						SrcStartIdx: a.scratch.outputSize,
   224  						SrcEndIdx:   a.scratch.resumeIdx + 1,
   225  					},
   226  				)
   227  				// Now we need to restore the desired length for the Vec.
   228  				vec.SetLength(a.scratch.inputSize)
   229  				a.aggregateFuncs[i].SetOutputIndex(newResumeIdx)
   230  			}
   231  		})
   232  		a.scratch.resumeIdx = newResumeIdx
   233  		if a.scratch.resumeIdx >= a.scratch.outputSize {
   234  			// We still have overflow output values.
   235  			a.scratch.SetLength(a.scratch.outputSize)
   236  			a.allocator.PerformOperation(a.unsafeBatch.ColVecs(), func() {
   237  				for i := 0; i < len(a.outputTypes); i++ {
   238  					a.unsafeBatch.ColVec(i).Copy(
   239  						coldata.CopySliceArgs{
   240  							SliceArgs: coldata.SliceArgs{
   241  								Src:         a.scratch.ColVec(i),
   242  								SrcStartIdx: 0,
   243  								SrcEndIdx:   a.scratch.Length(),
   244  							},
   245  						},
   246  					)
   247  				}
   248  				a.unsafeBatch.SetLength(a.scratch.Length())
   249  			})
   250  			a.scratch.shouldResetInternalBatch = false
   251  			return a.unsafeBatch
   252  		}
   253  	}
   254  
   255  	for a.scratch.resumeIdx < a.scratch.outputSize {
   256  		batch := a.input.Next(ctx)
   257  		a.seenNonEmptyBatch = a.seenNonEmptyBatch || batch.Length() > 0
   258  		if !a.seenNonEmptyBatch {
   259  			// The input has zero rows.
   260  			if a.isScalar {
   261  				for _, fn := range a.aggregateFuncs {
   262  					fn.HandleEmptyInputScalar()
   263  				}
   264  				// All aggregate functions will output a single value.
   265  				a.scratch.resumeIdx = 1
   266  			} else {
   267  				// There should be no output in non-scalar context for all aggregate
   268  				// functions.
   269  				a.scratch.resumeIdx = 0
   270  			}
   271  		} else {
   272  			if batch.Length() > 0 {
   273  				for i, fn := range a.aggregateFuncs {
   274  					fn.Compute(batch, a.aggCols[i])
   275  				}
   276  			} else {
   277  				for _, fn := range a.aggregateFuncs {
   278  					fn.Flush()
   279  				}
   280  			}
   281  			a.scratch.resumeIdx = a.aggregateFuncs[0].CurrentOutputIndex()
   282  		}
   283  		if batch.Length() == 0 {
   284  			a.done = true
   285  			break
   286  		}
   287  		// zero out a.groupCol. This is necessary because distinct ORs the
   288  		// uniqueness of a value with the groupCol, allowing the operators to be
   289  		// linked.
   290  		copy(a.groupCol, zeroBoolColumn)
   291  	}
   292  
   293  	batchToReturn := a.scratch.Batch
   294  	if a.scratch.resumeIdx > a.scratch.outputSize {
   295  		a.scratch.SetLength(a.scratch.outputSize)
   296  		a.allocator.PerformOperation(a.unsafeBatch.ColVecs(), func() {
   297  			for i := 0; i < len(a.outputTypes); i++ {
   298  				a.unsafeBatch.ColVec(i).Copy(
   299  					coldata.CopySliceArgs{
   300  						SliceArgs: coldata.SliceArgs{
   301  							Src:         a.scratch.ColVec(i),
   302  							SrcStartIdx: 0,
   303  							SrcEndIdx:   a.scratch.Length(),
   304  						},
   305  					},
   306  				)
   307  			}
   308  			a.unsafeBatch.SetLength(a.scratch.Length())
   309  		})
   310  		batchToReturn = a.unsafeBatch
   311  		a.scratch.shouldResetInternalBatch = false
   312  	} else {
   313  		a.scratch.SetLength(a.scratch.resumeIdx)
   314  		a.scratch.shouldResetInternalBatch = true
   315  	}
   316  
   317  	return batchToReturn
   318  }
   319  
   320  // reset resets the orderedAggregator for another run. Primarily used for
   321  // benchmarks.
   322  func (a *orderedAggregator) reset(ctx context.Context) {
   323  	if r, ok := a.input.(resetter); ok {
   324  		r.reset(ctx)
   325  	}
   326  	a.done = false
   327  	a.seenNonEmptyBatch = false
   328  	a.scratch.resumeIdx = 0
   329  	for _, fn := range a.aggregateFuncs {
   330  		fn.Reset()
   331  	}
   332  }