github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/colexec/ordered_aggregator.go (about) 1 // Copyright 2018 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package colexec 12 13 import ( 14 "context" 15 16 "github.com/cockroachdb/cockroach/pkg/col/coldata" 17 "github.com/cockroachdb/cockroach/pkg/sql/colexecbase" 18 "github.com/cockroachdb/cockroach/pkg/sql/colmem" 19 "github.com/cockroachdb/cockroach/pkg/sql/execinfrapb" 20 "github.com/cockroachdb/cockroach/pkg/sql/types" 21 "github.com/cockroachdb/errors" 22 ) 23 24 // orderedAggregator is an aggregator that performs arbitrary aggregations on 25 // input ordered by a set of grouping columns. Before performing any 26 // aggregations, the aggregator sets up a chain of distinct operators that will 27 // produce a vector of booleans (referenced in groupCol) that specifies whether 28 // or not the corresponding columns in the input batch are part of a new group. 29 // The memory is modified by the distinct operator flow. 30 // Every aggregate function will change the shape of the data. i.e. a new column 31 // value will be output for each input group. Since the number of input groups 32 // is variable and the number of output values is constant, care must be taken 33 // not to overflow the output buffer. To avoid having to perform bounds checks 34 // for the aggregate functions, the aggregator allocates twice the size of the 35 // input batch for the functions to write to. Before the next batch is 36 // processed, the aggregator checks what index the functions are outputting to. 37 // If greater than the expected output batch size by downstream operators, the 38 // overflow values are copied to the start of the batch. Since the input batch 39 // size is not necessarily the same as the output batch size, more than one copy 40 // and return must be performed until the aggregator is in a state where its 41 // functions are in a state where the output indices would not overflow the 42 // output batch if a worst case input batch is encountered (one where every 43 // value is part of a new group). 44 type orderedAggregator struct { 45 OneInputNode 46 47 allocator *colmem.Allocator 48 done bool 49 50 aggCols [][]uint32 51 aggTypes [][]*types.T 52 53 outputTypes []*types.T 54 55 // scratch is the Batch to output and variables related to it. Aggregate 56 // function operators write directly to this output batch. 57 scratch struct { 58 coldata.Batch 59 // shouldResetInternalBatch keeps track of whether the scratch.Batch should 60 // be reset. It is false in cases where we have overflow results still to 61 // return and therefore do not want to modify the batch. 62 shouldResetInternalBatch bool 63 // resumeIdx is the index at which the aggregation functions should start 64 // writing to on the next iteration of Next(). 65 resumeIdx int 66 // inputSize and outputSize are 2*coldata.BatchSize() and 67 // coldata.BatchSize(), respectively, by default but can be other values 68 // for tests. 69 inputSize int 70 outputSize int 71 } 72 73 // unsafeBatch is a coldata.Batch returned when only a subset of the 74 // scratch.Batch results is returned (i.e. work needs to be resumed on the 75 // next Next call). The values to return are copied into this batch to protect 76 // against downstream modification of the internal batch. 77 unsafeBatch coldata.Batch 78 79 // groupCol is the slice that aggregateFuncs use to determine whether a value 80 // is part of the current aggregation group. See aggregateFunc.Init for more 81 // information. 82 groupCol []bool 83 // aggregateFuncs are the aggregator's aggregate function operators. 84 aggregateFuncs []aggregateFunc 85 // isScalar indicates whether an aggregator is in scalar context. 86 isScalar bool 87 // seenNonEmptyBatch indicates whether a non-empty input batch has been 88 // observed. 89 seenNonEmptyBatch bool 90 } 91 92 var _ colexecbase.Operator = &orderedAggregator{} 93 94 // NewOrderedAggregator creates an ordered aggregator on the given grouping 95 // columns. aggCols is a slice where each index represents a new aggregation 96 // function. The slice at that index specifies the columns of the input batch 97 // that the aggregate function should work on. 98 func NewOrderedAggregator( 99 allocator *colmem.Allocator, 100 input colexecbase.Operator, 101 typs []*types.T, 102 aggFns []execinfrapb.AggregatorSpec_Func, 103 groupCols []uint32, 104 aggCols [][]uint32, 105 isScalar bool, 106 ) (colexecbase.Operator, error) { 107 if len(aggFns) != len(aggCols) { 108 return nil, 109 errors.Errorf( 110 "mismatched aggregation lengths: aggFns(%d), aggCols(%d)", 111 len(aggFns), 112 len(aggCols), 113 ) 114 } 115 116 aggTypes := extractAggTypes(aggCols, typs) 117 118 op, groupCol, err := OrderedDistinctColsToOperators(input, groupCols, typs) 119 if err != nil { 120 return nil, err 121 } 122 123 a := &orderedAggregator{} 124 if len(groupCols) == 0 { 125 // If there were no groupCols, we can't rely on the distinct operators to 126 // mark the first row as distinct, so we have to do it ourselves. Set up a 127 // oneShotOp to set the first row to distinct. 128 op = &oneShotOp{ 129 OneInputNode: NewOneInputNode(op), 130 fn: func(batch coldata.Batch) { 131 if batch.Length() == 0 { 132 return 133 } 134 if sel := batch.Selection(); sel != nil { 135 groupCol[sel[0]] = true 136 } else { 137 groupCol[0] = true 138 } 139 }, 140 outputSourceRef: &a.input, 141 } 142 } 143 144 *a = orderedAggregator{ 145 OneInputNode: NewOneInputNode(op), 146 147 allocator: allocator, 148 aggCols: aggCols, 149 aggTypes: aggTypes, 150 groupCol: groupCol, 151 isScalar: isScalar, 152 } 153 154 // We will be reusing the same aggregate functions, so we use 1 as the 155 // allocation size. 156 funcsAlloc, err := newAggregateFuncsAlloc(a.allocator, aggTypes, aggFns, 1 /* allocSize */) 157 if err != nil { 158 return nil, errors.AssertionFailedf( 159 "this error should have been checked in isAggregateSupported\n%+v", err, 160 ) 161 } 162 a.aggregateFuncs = funcsAlloc.makeAggregateFuncs() 163 a.outputTypes, err = makeAggregateFuncsOutputTypes(aggTypes, aggFns) 164 if err != nil { 165 return nil, errors.AssertionFailedf( 166 "this error should have been checked in isAggregateSupported\n%+v", err, 167 ) 168 } 169 170 return a, nil 171 } 172 173 func (a *orderedAggregator) initWithOutputBatchSize(outputSize int) { 174 a.initWithInputAndOutputBatchSize(coldata.BatchSize(), outputSize) 175 } 176 177 func (a *orderedAggregator) initWithInputAndOutputBatchSize(inputSize, outputSize int) { 178 a.input.Init() 179 180 // Twice the input batchSize is allocated to avoid having to check for 181 // overflow when outputting. 182 a.scratch.inputSize = inputSize * 2 183 a.scratch.outputSize = outputSize 184 a.scratch.Batch = a.allocator.NewMemBatchWithSize(a.outputTypes, a.scratch.inputSize) 185 for i := 0; i < len(a.outputTypes); i++ { 186 vec := a.scratch.ColVec(i) 187 a.aggregateFuncs[i].Init(a.groupCol, vec) 188 } 189 a.unsafeBatch = a.allocator.NewMemBatchWithSize(a.outputTypes, outputSize) 190 } 191 192 func (a *orderedAggregator) Init() { 193 a.initWithInputAndOutputBatchSize(coldata.BatchSize(), coldata.BatchSize()) 194 } 195 196 func (a *orderedAggregator) Next(ctx context.Context) coldata.Batch { 197 a.unsafeBatch.ResetInternalBatch() 198 if a.scratch.shouldResetInternalBatch { 199 a.scratch.ResetInternalBatch() 200 a.scratch.shouldResetInternalBatch = false 201 } 202 if a.done { 203 a.scratch.SetLength(0) 204 return a.scratch 205 } 206 if a.scratch.resumeIdx >= a.scratch.outputSize { 207 // Copy the second part of the output batch into the first and resume from 208 // there. 209 newResumeIdx := a.scratch.resumeIdx - a.scratch.outputSize 210 a.allocator.PerformOperation(a.scratch.ColVecs(), func() { 211 for i := 0; i < len(a.outputTypes); i++ { 212 vec := a.scratch.ColVec(i) 213 // According to the aggregate function interface contract, the value at 214 // the current index must also be copied. 215 // Note that we're using Append here instead of Copy because we want the 216 // "truncation" behavior, i.e. we want to copy over the remaining tuples 217 // such the "lengths" of the vectors are equal to the number of copied 218 // elements. 219 vec.Append( 220 coldata.SliceArgs{ 221 Src: vec, 222 DestIdx: 0, 223 SrcStartIdx: a.scratch.outputSize, 224 SrcEndIdx: a.scratch.resumeIdx + 1, 225 }, 226 ) 227 // Now we need to restore the desired length for the Vec. 228 vec.SetLength(a.scratch.inputSize) 229 a.aggregateFuncs[i].SetOutputIndex(newResumeIdx) 230 } 231 }) 232 a.scratch.resumeIdx = newResumeIdx 233 if a.scratch.resumeIdx >= a.scratch.outputSize { 234 // We still have overflow output values. 235 a.scratch.SetLength(a.scratch.outputSize) 236 a.allocator.PerformOperation(a.unsafeBatch.ColVecs(), func() { 237 for i := 0; i < len(a.outputTypes); i++ { 238 a.unsafeBatch.ColVec(i).Copy( 239 coldata.CopySliceArgs{ 240 SliceArgs: coldata.SliceArgs{ 241 Src: a.scratch.ColVec(i), 242 SrcStartIdx: 0, 243 SrcEndIdx: a.scratch.Length(), 244 }, 245 }, 246 ) 247 } 248 a.unsafeBatch.SetLength(a.scratch.Length()) 249 }) 250 a.scratch.shouldResetInternalBatch = false 251 return a.unsafeBatch 252 } 253 } 254 255 for a.scratch.resumeIdx < a.scratch.outputSize { 256 batch := a.input.Next(ctx) 257 a.seenNonEmptyBatch = a.seenNonEmptyBatch || batch.Length() > 0 258 if !a.seenNonEmptyBatch { 259 // The input has zero rows. 260 if a.isScalar { 261 for _, fn := range a.aggregateFuncs { 262 fn.HandleEmptyInputScalar() 263 } 264 // All aggregate functions will output a single value. 265 a.scratch.resumeIdx = 1 266 } else { 267 // There should be no output in non-scalar context for all aggregate 268 // functions. 269 a.scratch.resumeIdx = 0 270 } 271 } else { 272 if batch.Length() > 0 { 273 for i, fn := range a.aggregateFuncs { 274 fn.Compute(batch, a.aggCols[i]) 275 } 276 } else { 277 for _, fn := range a.aggregateFuncs { 278 fn.Flush() 279 } 280 } 281 a.scratch.resumeIdx = a.aggregateFuncs[0].CurrentOutputIndex() 282 } 283 if batch.Length() == 0 { 284 a.done = true 285 break 286 } 287 // zero out a.groupCol. This is necessary because distinct ORs the 288 // uniqueness of a value with the groupCol, allowing the operators to be 289 // linked. 290 copy(a.groupCol, zeroBoolColumn) 291 } 292 293 batchToReturn := a.scratch.Batch 294 if a.scratch.resumeIdx > a.scratch.outputSize { 295 a.scratch.SetLength(a.scratch.outputSize) 296 a.allocator.PerformOperation(a.unsafeBatch.ColVecs(), func() { 297 for i := 0; i < len(a.outputTypes); i++ { 298 a.unsafeBatch.ColVec(i).Copy( 299 coldata.CopySliceArgs{ 300 SliceArgs: coldata.SliceArgs{ 301 Src: a.scratch.ColVec(i), 302 SrcStartIdx: 0, 303 SrcEndIdx: a.scratch.Length(), 304 }, 305 }, 306 ) 307 } 308 a.unsafeBatch.SetLength(a.scratch.Length()) 309 }) 310 batchToReturn = a.unsafeBatch 311 a.scratch.shouldResetInternalBatch = false 312 } else { 313 a.scratch.SetLength(a.scratch.resumeIdx) 314 a.scratch.shouldResetInternalBatch = true 315 } 316 317 return batchToReturn 318 } 319 320 // reset resets the orderedAggregator for another run. Primarily used for 321 // benchmarks. 322 func (a *orderedAggregator) reset(ctx context.Context) { 323 if r, ok := a.input.(resetter); ok { 324 r.reset(ctx) 325 } 326 a.done = false 327 a.seenNonEmptyBatch = false 328 a.scratch.resumeIdx = 0 329 for _, fn := range a.aggregateFuncs { 330 fn.Reset() 331 } 332 }