github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/window_iter.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aggregation
    16  
    17  import (
    18  	"errors"
    19  	"io"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  )
    23  
    24  // WindowIter is a wrapper that evaluates a set of WindowPartitionIter.
    25  //
    26  // The current implementation has 3 steps:
    27  // 1. Materialize [iter] and duplicate a sql.WindowBuffer for each partition.
    28  // 2. Collect rows from child partitions.
    29  // 3. Rearrange partition results into the projected ordering given by [outputOrdinals].
    30  //
    31  // We assume [outputOrdinals] is appropriately sized for [partitionIters].
    32  type WindowIter struct {
    33  	partitionIters []*WindowPartitionIter
    34  	outputOrdinals [][]int
    35  	iter           sql.RowIter
    36  	initialized    bool
    37  }
    38  
    39  func NewWindowIter(partitionIters []*WindowPartitionIter, outputOrdinals [][]int, iter sql.RowIter) *WindowIter {
    40  	return &WindowIter{
    41  		partitionIters: partitionIters,
    42  		outputOrdinals: outputOrdinals,
    43  		iter:           iter,
    44  	}
    45  }
    46  
    47  var _ sql.RowIter = (*WindowIter)(nil)
    48  var _ sql.Disposable = (*WindowIter)(nil)
    49  
    50  // Close implements sql.RowIter
    51  func (i *WindowIter) Close(ctx *sql.Context) error {
    52  	i.Dispose()
    53  	var err error
    54  	for _, p := range i.partitionIters {
    55  		e := p.Close(ctx)
    56  		if err == nil && e != nil {
    57  			err = e
    58  		}
    59  	}
    60  	return err
    61  }
    62  
    63  // Dispose implements sql.Disposable
    64  func (i *WindowIter) Dispose() {
    65  	for _, p := range i.partitionIters {
    66  		p.Dispose()
    67  	}
    68  	return
    69  }
    70  
    71  // Next implements sql.RowIter
    72  func (i *WindowIter) Next(ctx *sql.Context) (sql.Row, error) {
    73  	if !i.initialized {
    74  		err := i.initializeIters(ctx)
    75  		if err != nil {
    76  			return nil, err
    77  		}
    78  	}
    79  
    80  	row := make(sql.Row, i.size())
    81  	for j, pIter := range i.partitionIters {
    82  		res, err := pIter.Next(ctx)
    83  		if err != nil {
    84  			return nil, err
    85  		}
    86  		for k, idx := range i.outputOrdinals[j] {
    87  			row[idx] = res[k]
    88  		}
    89  	}
    90  	return row, nil
    91  }
    92  
    93  func (i *WindowIter) size() int {
    94  	size := -1
    95  	for _, i := range i.outputOrdinals {
    96  		for _, j := range i {
    97  			if j > size {
    98  				size = j
    99  			}
   100  		}
   101  	}
   102  	return size + 1
   103  }
   104  
   105  // initializeIters materializes and copies the input buffer into each
   106  // WindowPartitionIter.
   107  // TODO: share the child buffer and sort/partition inbetween WindowPartitionIters
   108  func (i *WindowIter) initializeIters(ctx *sql.Context) error {
   109  	buf := make(sql.WindowBuffer, 0)
   110  	var row sql.Row
   111  	var err error
   112  	for {
   113  		// drain child iter into reusable buffer
   114  		row, err = i.iter.Next(ctx)
   115  		if errors.Is(err, io.EOF) {
   116  			break
   117  		}
   118  		if err != nil {
   119  			return err
   120  		}
   121  		buf = append(buf, row)
   122  	}
   123  
   124  	for _, i := range i.partitionIters {
   125  		// each iter has its own copy of input buffer
   126  		i.child = &windowBufferIter{buf: buf}
   127  	}
   128  	i.initialized = true
   129  	return nil
   130  }
   131  
   132  // windowBufferIter bridges an in-memory buffer to the sql.RowIter interface
   133  type windowBufferIter struct {
   134  	buf sql.WindowBuffer
   135  	pos int
   136  }
   137  
   138  func (i *windowBufferIter) Next(ctx *sql.Context) (sql.Row, error) {
   139  	if i.pos >= len(i.buf) {
   140  		return nil, io.EOF
   141  	}
   142  	row := i.buf[i.pos]
   143  	i.pos++
   144  	return row, nil
   145  }
   146  
   147  func (i *windowBufferIter) Close(ctx *sql.Context) error {
   148  	i.buf = nil
   149  	return nil
   150  }