github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/window_partition.go (about)

     1  // Copyright 2022 DoltHub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aggregation
    16  
    17  import (
    18  	"errors"
    19  	"io"
    20  	"sort"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/expression"
    24  )
    25  
    26  var ErrNoPartitions = errors.New("no partitions")
    27  
    28  // Aggregation comprises a sql.WindowFunction and a companion sql.WindowFramer.
    29  // A parent WindowPartitionIter feeds [fn] with intervals from the [framer].
    30  // Iteration logic is divided between [fn] and [framer] depending on context.
    31  // For example, some aggregation functions like PercentRank and CountAgg track peer
    32  // groups within a partition, more state than the framer provides.
    33  type Aggregation struct {
    34  	fn     sql.WindowFunction
    35  	framer sql.WindowFramer
    36  }
    37  
    38  func NewAggregation(a sql.WindowFunction, f sql.WindowFramer) *Aggregation {
    39  	return &Aggregation{fn: a, framer: f}
    40  }
    41  
    42  // startPartition disposes and recreates [framer] and resets the internal state of the aggregation [fn].
    43  func (a *Aggregation) startPartition(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) error {
    44  	err := a.fn.StartPartition(ctx, interval, buf)
    45  	if err != nil {
    46  		return err
    47  	}
    48  	a.framer, err = a.framer.NewFramer(interval)
    49  	if err != nil {
    50  		return err
    51  	}
    52  	return nil
    53  }
    54  
    55  // WindowPartition is an Aggregation set with unique partition and sorting keys.
    56  // There may be several WindowPartitions in one query, but each has unique key set.
    57  // A WindowPartitionIter is used to evaluate a WindowPartition with a specific sql.RowIter.
    58  type WindowPartition struct {
    59  	PartitionBy []sql.Expression
    60  	SortBy      sql.SortFields
    61  	Aggs        []*Aggregation
    62  }
    63  
    64  func NewWindowPartition(partitionBy []sql.Expression, sortBy sql.SortFields, aggs []*Aggregation) *WindowPartition {
    65  	return &WindowPartition{
    66  		PartitionBy: partitionBy,
    67  		SortBy:      sortBy,
    68  		Aggs:        aggs,
    69  	}
    70  }
    71  
    72  func (w *WindowPartition) AddAggregation(agg *Aggregation) {
    73  	w.Aggs = append(w.Aggs, agg)
    74  }
    75  
    76  // WindowPartitionIter evaluates a WindowPartition with a sql.RowIter child.
    77  // A parent WindowIter is expected to maintain the projection ordering for
    78  // WindowPartition output columns.
    79  //
    80  // WindowPartitionIter will return rows sorted in the same order
    81  // generated by [child]. This is accomplished privately by appending
    82  // the sort ordering index to [i.input] rows during materializeInput,
    83  // and removing after sortAndFilterOutput.
    84  //
    85  // Next currently materializes [i.input] and [i.output] before
    86  // returning the first result, regardless of Limit or other expressions.
    87  type WindowPartitionIter struct {
    88  	w             *WindowPartition
    89  	child         sql.RowIter
    90  	input, output sql.WindowBuffer
    91  
    92  	pos               int
    93  	outputOrderingPos int
    94  	outputOrdering    []int
    95  
    96  	partitions       []sql.WindowInterval
    97  	currentPartition sql.WindowInterval
    98  	partitionIdx     int
    99  }
   100  
   101  var _ sql.RowIter = (*WindowPartitionIter)(nil)
   102  var _ sql.Disposable = (*WindowPartitionIter)(nil)
   103  
   104  func NewWindowPartitionIter(windowBlock *WindowPartition) *WindowPartitionIter {
   105  	return &WindowPartitionIter{
   106  		w:            windowBlock,
   107  		partitionIdx: -1,
   108  	}
   109  }
   110  
   111  func (i *WindowPartitionIter) WindowBlock() *WindowPartition {
   112  	return i.w
   113  }
   114  
   115  func (i *WindowPartitionIter) Close(ctx *sql.Context) error {
   116  	i.Dispose()
   117  	i.input = nil
   118  	return nil
   119  }
   120  
   121  func (i *WindowPartitionIter) Dispose() {
   122  	for _, a := range i.w.Aggs {
   123  		a.fn.Dispose()
   124  	}
   125  }
   126  
   127  func (i *WindowPartitionIter) Next(ctx *sql.Context) (sql.Row, error) {
   128  	var err error
   129  	if i.output == nil {
   130  		i.input, i.outputOrdering, err = i.materializeInput(ctx)
   131  		if err != nil {
   132  			return nil, err
   133  		}
   134  
   135  		i.partitions, err = i.initializePartitions(ctx)
   136  		if err != nil {
   137  			return nil, err
   138  		}
   139  
   140  		i.output, err = i.materializeOutput(ctx)
   141  		if err != nil {
   142  			return nil, err
   143  		}
   144  
   145  		err = i.sortAndFilterOutput()
   146  		if err != nil {
   147  			return nil, err
   148  		}
   149  	}
   150  
   151  	if i.pos > len(i.output)-1 {
   152  		return nil, io.EOF
   153  	}
   154  
   155  	defer func() { i.pos++ }()
   156  
   157  	return i.output[i.pos], nil
   158  }
   159  
   160  // materializeInput empties the child iterator into a buffer and sorts by (WPK, WSK). Returns
   161  // a sorted sql.WindowBuffer and a list of original row indices for resorting.
   162  func (i *WindowPartitionIter) materializeInput(ctx *sql.Context) (sql.WindowBuffer, []int, error) {
   163  	input := make(sql.WindowBuffer, 0)
   164  	j := 0
   165  	for {
   166  		row, err := i.child.Next(ctx)
   167  		if err != nil {
   168  			if err == io.EOF {
   169  				break
   170  			}
   171  			return nil, nil, err
   172  		}
   173  		input = append(input, append(row, j))
   174  		j++
   175  	}
   176  
   177  	if len(input) == 0 {
   178  		return nil, nil, nil
   179  	}
   180  
   181  	// sort all rows by partition
   182  	sorter := &expression.Sorter{
   183  		SortFields: append(partitionsToSortFields(i.w.PartitionBy), i.w.SortBy...),
   184  		Rows:       input,
   185  		Ctx:        ctx,
   186  	}
   187  	sort.Stable(sorter)
   188  
   189  	// maintain output sort ordering
   190  	// TODO: push sort above aggregation, makes this code unnecessarily complex
   191  	outputOrdering := make([]int, len(input))
   192  	outputIdx := len(input[0]) - 1
   193  	for k, row := range input {
   194  		outputOrdering[k], input[k] = row[outputIdx].(int), row[:outputIdx]
   195  	}
   196  
   197  	return input, outputOrdering, nil
   198  }
   199  
   200  // initializePartitions walks the [i.input] buffer using [i.PartitionBy] and
   201  // returns a list of sql.WindowInterval [partition]s.
   202  func (i *WindowPartitionIter) initializePartitions(ctx *sql.Context) ([]sql.WindowInterval, error) {
   203  	if len(i.input) == 0 {
   204  		// Some conditions require a default output for nil input rows. The
   205  		// empty partition lets window framing pass through one io.EOF to
   206  		// provide a default result before stopping for these cases.
   207  		return []sql.WindowInterval{{Start: 0, End: 0}}, nil
   208  	}
   209  
   210  	partitions := make([]sql.WindowInterval, 0)
   211  	startIdx := 0
   212  	var lastRow sql.Row
   213  	for j, row := range i.input {
   214  		newPart, err := isNewPartition(ctx, i.w.PartitionBy, lastRow, row)
   215  		if err != nil {
   216  			return nil, err
   217  		}
   218  		if newPart && j > startIdx {
   219  			partitions = append(partitions, sql.WindowInterval{Start: startIdx, End: j})
   220  			startIdx = j
   221  		}
   222  		lastRow = row
   223  	}
   224  
   225  	if startIdx < len(i.input) {
   226  		partitions = append(partitions, sql.WindowInterval{Start: startIdx, End: len(i.input)})
   227  	}
   228  
   229  	return partitions, nil
   230  }
   231  
   232  // materializeOutput evaluates and collects all aggregation results into an output sql.WindowBuffer.
   233  // At this stage, result rows are appended with the original row index for resorting. The size of
   234  // [i.output] will be smaller than [i.input] if the outer sql.Node is a plan.GroupBy with fewer partitions than rows.
   235  func (i *WindowPartitionIter) materializeOutput(ctx *sql.Context) (sql.WindowBuffer, error) {
   236  	// handle nil input specially if no partition clause
   237  	// ex: COUNT(*) on nil rows returns 0, not nil
   238  	if len(i.input) == 0 && len(i.w.PartitionBy) > 0 {
   239  		return nil, io.EOF
   240  	}
   241  
   242  	output := make(sql.WindowBuffer, 0, len(i.input))
   243  	var row sql.Row
   244  	var err error
   245  	for {
   246  		row, err = i.compute(ctx)
   247  		if errors.Is(err, io.EOF) {
   248  			break
   249  		} else if err != nil {
   250  			return nil, err
   251  		}
   252  		output = append(output, row)
   253  	}
   254  
   255  	return output, nil
   256  }
   257  
   258  // compute evaluates each function in [i.Aggs], returning the result as an sql.Row with
   259  // the outputOrdering index appended, or an io.EOF error if we are finished iterating.
   260  func (i *WindowPartitionIter) compute(ctx *sql.Context) (sql.Row, error) {
   261  	var row = make(sql.Row, len(i.w.Aggs)+1)
   262  
   263  	// each [agg] has its own [agg.framer] that is globally positioned
   264  	// but updated independently. This allows aggregations with the same
   265  	// partition and sorting to have different framing behavior.
   266  	for j, agg := range i.w.Aggs {
   267  		interval, err := agg.framer.Next(ctx, i.input)
   268  		if errors.Is(err, io.EOF) {
   269  			err = i.nextPartition(ctx)
   270  			if err != nil {
   271  				return nil, err
   272  			}
   273  			interval, err = agg.framer.Next(ctx, i.input)
   274  			if err != nil {
   275  				return nil, err
   276  			}
   277  		}
   278  		row[j] = agg.fn.Compute(ctx, interval, i.input)
   279  	}
   280  
   281  	// TODO: move sort by above aggregation
   282  	if len(i.outputOrdering) > 0 {
   283  		row[len(i.w.Aggs)] = i.outputOrdering[i.outputOrderingPos]
   284  	}
   285  
   286  	i.outputOrderingPos++
   287  	return row, nil
   288  }
   289  
   290  // sortAndFilterOutput in-place sorts the [i.output] buffer using the last
   291  // value in every row as the sort index.
   292  func (i *WindowPartitionIter) sortAndFilterOutput() error {
   293  	// TODO: move sort by above aggregations
   294  	// we could cycle sort this for windows (not group by, unless number
   295  	// of group by partitions = number of rows)
   296  	if len(i.output) == 0 {
   297  		return nil
   298  	}
   299  
   300  	originalOrderIdx := len(i.output[0]) - 1
   301  	sort.SliceStable(i.output, func(j, k int) bool {
   302  		return i.output[j][originalOrderIdx].(int) < i.output[k][originalOrderIdx].(int)
   303  	})
   304  
   305  	for j, row := range i.output {
   306  		i.output[j] = row[:originalOrderIdx]
   307  	}
   308  
   309  	return nil
   310  }
   311  
   312  func (i *WindowPartitionIter) nextPartition(ctx *sql.Context) error {
   313  	if len(i.partitions) == 0 {
   314  		return ErrNoPartitions
   315  	}
   316  
   317  	if i.partitionIdx < 0 {
   318  		i.partitionIdx = 0
   319  	} else {
   320  		i.partitionIdx++
   321  	}
   322  
   323  	if i.partitionIdx > len(i.partitions)-1 {
   324  		return io.EOF
   325  	}
   326  
   327  	i.currentPartition = i.partitions[i.partitionIdx]
   328  	i.outputOrderingPos = i.currentPartition.Start
   329  
   330  	var err error
   331  	for _, a := range i.w.Aggs {
   332  		err = a.startPartition(ctx, i.currentPartition, i.input)
   333  		if err != nil {
   334  			return err
   335  		}
   336  	}
   337  
   338  	return nil
   339  }
   340  
   341  func partitionsToSortFields(partitionExprs []sql.Expression) sql.SortFields {
   342  	sfs := make(sql.SortFields, len(partitionExprs))
   343  	for i, expr := range partitionExprs {
   344  		sfs[i] = sql.SortField{
   345  			Column: expr,
   346  			Order:  sql.Ascending,
   347  		}
   348  	}
   349  	return sfs
   350  }
   351  
   352  func isNewPartition(ctx *sql.Context, partitionBy []sql.Expression, last sql.Row, row sql.Row) (bool, error) {
   353  	if len(last) == 0 {
   354  		return true, nil
   355  	}
   356  
   357  	if len(partitionBy) == 0 {
   358  		return false, nil
   359  	}
   360  
   361  	lastExp, _, err := evalExprs(ctx, partitionBy, last)
   362  	if err != nil {
   363  		return false, err
   364  	}
   365  
   366  	thisExp, _, err := evalExprs(ctx, partitionBy, row)
   367  	if err != nil {
   368  		return false, err
   369  	}
   370  
   371  	for i, expr := range partitionBy {
   372  		cmp, err := expr.Type().Compare(lastExp[i], thisExp[i])
   373  		if err != nil {
   374  			return false, err
   375  		}
   376  		if cmp != 0 {
   377  			return true, nil
   378  		}
   379  	}
   380  
   381  	return false, nil
   382  }