github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/window_partition.go (about) 1 // Copyright 2022 DoltHub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aggregation 16 17 import ( 18 "errors" 19 "io" 20 "sort" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/expression" 24 ) 25 26 var ErrNoPartitions = errors.New("no partitions") 27 28 // Aggregation comprises a sql.WindowFunction and a companion sql.WindowFramer. 29 // A parent WindowPartitionIter feeds [fn] with intervals from the [framer]. 30 // Iteration logic is divided between [fn] and [framer] depending on context. 31 // For example, some aggregation functions like PercentRank and CountAgg track peer 32 // groups within a partition, more state than the framer provides. 33 type Aggregation struct { 34 fn sql.WindowFunction 35 framer sql.WindowFramer 36 } 37 38 func NewAggregation(a sql.WindowFunction, f sql.WindowFramer) *Aggregation { 39 return &Aggregation{fn: a, framer: f} 40 } 41 42 // startPartition disposes and recreates [framer] and resets the internal state of the aggregation [fn]. 43 func (a *Aggregation) startPartition(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) error { 44 err := a.fn.StartPartition(ctx, interval, buf) 45 if err != nil { 46 return err 47 } 48 a.framer, err = a.framer.NewFramer(interval) 49 if err != nil { 50 return err 51 } 52 return nil 53 } 54 55 // WindowPartition is an Aggregation set with unique partition and sorting keys. 56 // There may be several WindowPartitions in one query, but each has unique key set. 57 // A WindowPartitionIter is used to evaluate a WindowPartition with a specific sql.RowIter. 58 type WindowPartition struct { 59 PartitionBy []sql.Expression 60 SortBy sql.SortFields 61 Aggs []*Aggregation 62 } 63 64 func NewWindowPartition(partitionBy []sql.Expression, sortBy sql.SortFields, aggs []*Aggregation) *WindowPartition { 65 return &WindowPartition{ 66 PartitionBy: partitionBy, 67 SortBy: sortBy, 68 Aggs: aggs, 69 } 70 } 71 72 func (w *WindowPartition) AddAggregation(agg *Aggregation) { 73 w.Aggs = append(w.Aggs, agg) 74 } 75 76 // WindowPartitionIter evaluates a WindowPartition with a sql.RowIter child. 77 // A parent WindowIter is expected to maintain the projection ordering for 78 // WindowPartition output columns. 79 // 80 // WindowPartitionIter will return rows sorted in the same order 81 // generated by [child]. This is accomplished privately by appending 82 // the sort ordering index to [i.input] rows during materializeInput, 83 // and removing after sortAndFilterOutput. 84 // 85 // Next currently materializes [i.input] and [i.output] before 86 // returning the first result, regardless of Limit or other expressions. 87 type WindowPartitionIter struct { 88 w *WindowPartition 89 child sql.RowIter 90 input, output sql.WindowBuffer 91 92 pos int 93 outputOrderingPos int 94 outputOrdering []int 95 96 partitions []sql.WindowInterval 97 currentPartition sql.WindowInterval 98 partitionIdx int 99 } 100 101 var _ sql.RowIter = (*WindowPartitionIter)(nil) 102 var _ sql.Disposable = (*WindowPartitionIter)(nil) 103 104 func NewWindowPartitionIter(windowBlock *WindowPartition) *WindowPartitionIter { 105 return &WindowPartitionIter{ 106 w: windowBlock, 107 partitionIdx: -1, 108 } 109 } 110 111 func (i *WindowPartitionIter) WindowBlock() *WindowPartition { 112 return i.w 113 } 114 115 func (i *WindowPartitionIter) Close(ctx *sql.Context) error { 116 i.Dispose() 117 i.input = nil 118 return nil 119 } 120 121 func (i *WindowPartitionIter) Dispose() { 122 for _, a := range i.w.Aggs { 123 a.fn.Dispose() 124 } 125 } 126 127 func (i *WindowPartitionIter) Next(ctx *sql.Context) (sql.Row, error) { 128 var err error 129 if i.output == nil { 130 i.input, i.outputOrdering, err = i.materializeInput(ctx) 131 if err != nil { 132 return nil, err 133 } 134 135 i.partitions, err = i.initializePartitions(ctx) 136 if err != nil { 137 return nil, err 138 } 139 140 i.output, err = i.materializeOutput(ctx) 141 if err != nil { 142 return nil, err 143 } 144 145 err = i.sortAndFilterOutput() 146 if err != nil { 147 return nil, err 148 } 149 } 150 151 if i.pos > len(i.output)-1 { 152 return nil, io.EOF 153 } 154 155 defer func() { i.pos++ }() 156 157 return i.output[i.pos], nil 158 } 159 160 // materializeInput empties the child iterator into a buffer and sorts by (WPK, WSK). Returns 161 // a sorted sql.WindowBuffer and a list of original row indices for resorting. 162 func (i *WindowPartitionIter) materializeInput(ctx *sql.Context) (sql.WindowBuffer, []int, error) { 163 input := make(sql.WindowBuffer, 0) 164 j := 0 165 for { 166 row, err := i.child.Next(ctx) 167 if err != nil { 168 if err == io.EOF { 169 break 170 } 171 return nil, nil, err 172 } 173 input = append(input, append(row, j)) 174 j++ 175 } 176 177 if len(input) == 0 { 178 return nil, nil, nil 179 } 180 181 // sort all rows by partition 182 sorter := &expression.Sorter{ 183 SortFields: append(partitionsToSortFields(i.w.PartitionBy), i.w.SortBy...), 184 Rows: input, 185 Ctx: ctx, 186 } 187 sort.Stable(sorter) 188 189 // maintain output sort ordering 190 // TODO: push sort above aggregation, makes this code unnecessarily complex 191 outputOrdering := make([]int, len(input)) 192 outputIdx := len(input[0]) - 1 193 for k, row := range input { 194 outputOrdering[k], input[k] = row[outputIdx].(int), row[:outputIdx] 195 } 196 197 return input, outputOrdering, nil 198 } 199 200 // initializePartitions walks the [i.input] buffer using [i.PartitionBy] and 201 // returns a list of sql.WindowInterval [partition]s. 202 func (i *WindowPartitionIter) initializePartitions(ctx *sql.Context) ([]sql.WindowInterval, error) { 203 if len(i.input) == 0 { 204 // Some conditions require a default output for nil input rows. The 205 // empty partition lets window framing pass through one io.EOF to 206 // provide a default result before stopping for these cases. 207 return []sql.WindowInterval{{Start: 0, End: 0}}, nil 208 } 209 210 partitions := make([]sql.WindowInterval, 0) 211 startIdx := 0 212 var lastRow sql.Row 213 for j, row := range i.input { 214 newPart, err := isNewPartition(ctx, i.w.PartitionBy, lastRow, row) 215 if err != nil { 216 return nil, err 217 } 218 if newPart && j > startIdx { 219 partitions = append(partitions, sql.WindowInterval{Start: startIdx, End: j}) 220 startIdx = j 221 } 222 lastRow = row 223 } 224 225 if startIdx < len(i.input) { 226 partitions = append(partitions, sql.WindowInterval{Start: startIdx, End: len(i.input)}) 227 } 228 229 return partitions, nil 230 } 231 232 // materializeOutput evaluates and collects all aggregation results into an output sql.WindowBuffer. 233 // At this stage, result rows are appended with the original row index for resorting. The size of 234 // [i.output] will be smaller than [i.input] if the outer sql.Node is a plan.GroupBy with fewer partitions than rows. 235 func (i *WindowPartitionIter) materializeOutput(ctx *sql.Context) (sql.WindowBuffer, error) { 236 // handle nil input specially if no partition clause 237 // ex: COUNT(*) on nil rows returns 0, not nil 238 if len(i.input) == 0 && len(i.w.PartitionBy) > 0 { 239 return nil, io.EOF 240 } 241 242 output := make(sql.WindowBuffer, 0, len(i.input)) 243 var row sql.Row 244 var err error 245 for { 246 row, err = i.compute(ctx) 247 if errors.Is(err, io.EOF) { 248 break 249 } else if err != nil { 250 return nil, err 251 } 252 output = append(output, row) 253 } 254 255 return output, nil 256 } 257 258 // compute evaluates each function in [i.Aggs], returning the result as an sql.Row with 259 // the outputOrdering index appended, or an io.EOF error if we are finished iterating. 260 func (i *WindowPartitionIter) compute(ctx *sql.Context) (sql.Row, error) { 261 var row = make(sql.Row, len(i.w.Aggs)+1) 262 263 // each [agg] has its own [agg.framer] that is globally positioned 264 // but updated independently. This allows aggregations with the same 265 // partition and sorting to have different framing behavior. 266 for j, agg := range i.w.Aggs { 267 interval, err := agg.framer.Next(ctx, i.input) 268 if errors.Is(err, io.EOF) { 269 err = i.nextPartition(ctx) 270 if err != nil { 271 return nil, err 272 } 273 interval, err = agg.framer.Next(ctx, i.input) 274 if err != nil { 275 return nil, err 276 } 277 } 278 row[j] = agg.fn.Compute(ctx, interval, i.input) 279 } 280 281 // TODO: move sort by above aggregation 282 if len(i.outputOrdering) > 0 { 283 row[len(i.w.Aggs)] = i.outputOrdering[i.outputOrderingPos] 284 } 285 286 i.outputOrderingPos++ 287 return row, nil 288 } 289 290 // sortAndFilterOutput in-place sorts the [i.output] buffer using the last 291 // value in every row as the sort index. 292 func (i *WindowPartitionIter) sortAndFilterOutput() error { 293 // TODO: move sort by above aggregations 294 // we could cycle sort this for windows (not group by, unless number 295 // of group by partitions = number of rows) 296 if len(i.output) == 0 { 297 return nil 298 } 299 300 originalOrderIdx := len(i.output[0]) - 1 301 sort.SliceStable(i.output, func(j, k int) bool { 302 return i.output[j][originalOrderIdx].(int) < i.output[k][originalOrderIdx].(int) 303 }) 304 305 for j, row := range i.output { 306 i.output[j] = row[:originalOrderIdx] 307 } 308 309 return nil 310 } 311 312 func (i *WindowPartitionIter) nextPartition(ctx *sql.Context) error { 313 if len(i.partitions) == 0 { 314 return ErrNoPartitions 315 } 316 317 if i.partitionIdx < 0 { 318 i.partitionIdx = 0 319 } else { 320 i.partitionIdx++ 321 } 322 323 if i.partitionIdx > len(i.partitions)-1 { 324 return io.EOF 325 } 326 327 i.currentPartition = i.partitions[i.partitionIdx] 328 i.outputOrderingPos = i.currentPartition.Start 329 330 var err error 331 for _, a := range i.w.Aggs { 332 err = a.startPartition(ctx, i.currentPartition, i.input) 333 if err != nil { 334 return err 335 } 336 } 337 338 return nil 339 } 340 341 func partitionsToSortFields(partitionExprs []sql.Expression) sql.SortFields { 342 sfs := make(sql.SortFields, len(partitionExprs)) 343 for i, expr := range partitionExprs { 344 sfs[i] = sql.SortField{ 345 Column: expr, 346 Order: sql.Ascending, 347 } 348 } 349 return sfs 350 } 351 352 func isNewPartition(ctx *sql.Context, partitionBy []sql.Expression, last sql.Row, row sql.Row) (bool, error) { 353 if len(last) == 0 { 354 return true, nil 355 } 356 357 if len(partitionBy) == 0 { 358 return false, nil 359 } 360 361 lastExp, _, err := evalExprs(ctx, partitionBy, last) 362 if err != nil { 363 return false, err 364 } 365 366 thisExp, _, err := evalExprs(ctx, partitionBy, row) 367 if err != nil { 368 return false, err 369 } 370 371 for i, expr := range partitionBy { 372 cmp, err := expr.Type().Compare(lastExp[i], thisExp[i]) 373 if err != nil { 374 return false, err 375 } 376 if cmp != 0 { 377 return true, nil 378 } 379 } 380 381 return false, nil 382 }