github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/rel_iters.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"container/heap"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"sort"
    23  	"strings"
    24  
    25  	"github.com/dolthub/jsonpath"
    26  
    27  	"github.com/dolthub/go-mysql-server/sql"
    28  	"github.com/dolthub/go-mysql-server/sql/expression"
    29  	"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation"
    30  	"github.com/dolthub/go-mysql-server/sql/plan"
    31  	"github.com/dolthub/go-mysql-server/sql/types"
    32  )
    33  
// topRowsIter returns the first limit rows of childIter according to
// sortFields. Rather than sorting the whole input, it retains at most limit
// rows in a heap while draining the child (see computeTopRows).
type topRowsIter struct {
	// sortFields is the ORDER BY specification. newTopRowsIter appends a final
	// field on the arrival ordinal that computeTopRows adds to every row,
	// presumably so that ties are broken by input order.
	sortFields sql.SortFields
	// calcFoundRows makes Close record numFoundRows as sql.FoundRows.
	calcFoundRows bool
	// childIter supplies the input rows.
	childIter sql.RowIter
	// limit is the maximum number of rows returned.
	limit int64
	// topRows holds the selected rows, each still carrying the trailing
	// ordinal column that Next strips off.
	topRows []sql.Row
	// numFoundRows counts every row read from childIter.
	numFoundRows int64
	// idx is the position in topRows of the next row to return; -1 until
	// topRows has been computed.
	idx int
}
    43  
    44  func newTopRowsIter(s sql.SortFields, limit int64, calcFoundRows bool, child sql.RowIter, childSchemaLen int) *topRowsIter {
    45  	return &topRowsIter{
    46  		sortFields:    append(s, sql.SortField{Column: expression.NewGetField(childSchemaLen, types.Int64, "order", false)}),
    47  		limit:         limit,
    48  		calcFoundRows: calcFoundRows,
    49  		childIter:     child,
    50  		idx:           -1,
    51  	}
    52  }
    53  
    54  func (i *topRowsIter) Next(ctx *sql.Context) (sql.Row, error) {
    55  	if i.idx == -1 {
    56  		err := i.computeTopRows(ctx)
    57  		if err != nil {
    58  			return nil, err
    59  		}
    60  		i.idx = 0
    61  	}
    62  
    63  	if i.idx >= len(i.topRows) {
    64  		return nil, io.EOF
    65  	}
    66  	row := i.topRows[i.idx]
    67  	i.idx++
    68  	return row[:len(row)-1], nil
    69  }
    70  
    71  func (i *topRowsIter) Close(ctx *sql.Context) error {
    72  	i.topRows = nil
    73  
    74  	if i.calcFoundRows {
    75  		ctx.SetLastQueryInfo(sql.FoundRows, i.numFoundRows)
    76  	}
    77  
    78  	return i.childIter.Close(ctx)
    79  }
    80  
    81  func (i *topRowsIter) computeTopRows(ctx *sql.Context) error {
    82  	topRowsHeap := &expression.TopRowsHeap{
    83  		expression.Sorter{
    84  			SortFields: i.sortFields,
    85  			Rows:       []sql.Row{},
    86  			LastError:  nil,
    87  			Ctx:        ctx,
    88  		},
    89  	}
    90  	for {
    91  		row, err := i.childIter.Next(ctx)
    92  		if err == io.EOF {
    93  			break
    94  		}
    95  		if err != nil {
    96  			return err
    97  		}
    98  		i.numFoundRows++
    99  
   100  		row = append(row, i.numFoundRows)
   101  
   102  		heap.Push(topRowsHeap, row)
   103  		if int64(topRowsHeap.Len()) > i.limit {
   104  			heap.Pop(topRowsHeap)
   105  		}
   106  		if topRowsHeap.LastError != nil {
   107  			return topRowsHeap.LastError
   108  		}
   109  	}
   110  
   111  	var err error
   112  	i.topRows, err = topRowsHeap.Rows()
   113  	return err
   114  }
   115  
   116  // getInt64Value returns the int64 literal value in the expression given, or an error with the errStr given if it
   117  // cannot.
   118  func getInt64Value(ctx *sql.Context, expr sql.Expression) (int64, error) {
   119  	i, err := expr.Eval(ctx, nil)
   120  	if err != nil {
   121  		return 0, err
   122  	}
   123  
   124  	switch i := i.(type) {
   125  	case int:
   126  		return int64(i), nil
   127  	case int8:
   128  		return int64(i), nil
   129  	case int16:
   130  		return int64(i), nil
   131  	case int32:
   132  		return int64(i), nil
   133  	case int64:
   134  		return i, nil
   135  	case uint:
   136  		return int64(i), nil
   137  	case uint8:
   138  		return int64(i), nil
   139  	case uint16:
   140  		return int64(i), nil
   141  	case uint32:
   142  		return int64(i), nil
   143  	case uint64:
   144  		return int64(i), nil
   145  	default:
   146  		// analyzer should catch this already
   147  		panic(fmt.Sprintf("Unsupported type for limit %T", i))
   148  	}
   149  }
   150  
   151  // windowToIter transforms a plan.Window into a series
   152  // of aggregation.WindowPartitionIter and a list of output projection indexes
   153  // for each window partition.
   154  // TODO: make partition ordering deterministic
   155  func windowToIter(w *plan.Window) ([]*aggregation.WindowPartitionIter, [][]int, error) {
   156  	partIdToOutputIdxs := make(map[uint64][]int, 0)
   157  	partIdToBlock := make(map[uint64]*aggregation.WindowPartition, 0)
   158  	var window *sql.WindowDefinition
   159  	var agg *aggregation.Aggregation
   160  	var fn sql.WindowFunction
   161  	var err error
   162  	// collect functions in hash map keyed by partitioning scheme
   163  	for i, expr := range w.SelectExprs {
   164  		if alias, ok := expr.(*expression.Alias); ok {
   165  			expr = alias.Child
   166  		}
   167  		switch e := expr.(type) {
   168  		case sql.Aggregation:
   169  			window = e.Window()
   170  			fn, err = e.NewWindowFunction()
   171  		case sql.WindowAggregation:
   172  			window = e.Window()
   173  			fn, err = e.NewWindowFunction()
   174  		default:
   175  			// non window aggregates resolve to LastAgg with empty over clause
   176  			window = sql.NewWindowDefinition(nil, nil, nil, "", "")
   177  			fn, err = aggregation.NewLast(e).NewWindowFunction()
   178  		}
   179  		if err != nil {
   180  			return nil, nil, err
   181  		}
   182  		agg = aggregation.NewAggregation(fn, fn.DefaultFramer())
   183  
   184  		id, err := window.PartitionId()
   185  		if err != nil {
   186  			return nil, nil, err
   187  		}
   188  
   189  		if block, ok := partIdToBlock[id]; !ok {
   190  			if err != nil {
   191  				return nil, nil, err
   192  			}
   193  			partIdToBlock[id] = aggregation.NewWindowPartition(
   194  				window.PartitionBy,
   195  				window.OrderBy,
   196  				[]*aggregation.Aggregation{agg},
   197  			)
   198  			partIdToOutputIdxs[id] = []int{i}
   199  		} else {
   200  			block.AddAggregation(agg)
   201  			partIdToOutputIdxs[id] = append(partIdToOutputIdxs[id], i)
   202  		}
   203  	}
   204  
   205  	// convert partition hash map into list
   206  	blockIters := make([]*aggregation.WindowPartitionIter, len(partIdToBlock))
   207  	outputOrdinals := make([][]int, len(partIdToBlock))
   208  	i := 0
   209  	for id, block := range partIdToBlock {
   210  		outputIdx := partIdToOutputIdxs[id]
   211  		blockIters[i] = aggregation.NewWindowPartitionIter(block)
   212  		outputOrdinals[i] = outputIdx
   213  		i++
   214  	}
   215  	return blockIters, outputOrdinals, nil
   216  }
   217  
   218  type offsetIter struct {
   219  	skip      int64
   220  	childIter sql.RowIter
   221  }
   222  
   223  func (i *offsetIter) Next(ctx *sql.Context) (sql.Row, error) {
   224  	if i.skip > 0 {
   225  		for i.skip > 0 {
   226  			_, err := i.childIter.Next(ctx)
   227  			if err != nil {
   228  				return nil, err
   229  			}
   230  			i.skip--
   231  		}
   232  	}
   233  
   234  	row, err := i.childIter.Next(ctx)
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  
   239  	return row, nil
   240  }
   241  
   242  func (i *offsetIter) Close(ctx *sql.Context) error {
   243  	return i.childIter.Close(ctx)
   244  }
   245  
// jsonTableColOpts holds the options of a leaf JSON_TABLE column, as consumed
// by jsonTableCol.Next.
type jsonTableColOpts struct {
	name      string      // column name, used in error messages
	typ       sql.Type    // type every extracted value is converted to
	forOrd    bool        // FOR ORDINALITY column: emits the row ordinal instead of a value
	exists    bool        // emits 1 when the path lookup succeeds and 0 when it fails
	defErrVal interface{} // converted and used when conversion fails and errOnErr is unset
	defEmpVal interface{} // used when the path lookup fails and errOnEmp is unset
	errOnErr  bool        // return the conversion error instead of using defErrVal
	errOnEmp  bool        // return an error when the path lookup fails instead of using defEmpVal
}
   256  
// jsonTableCol represents a column in a json table. A column with nested cols
// iterates over the values found at path. A leaf column extracts a single value
// from the object handed to it by its parent.
type jsonTableCol struct {
	path string            // if there are nested columns, this is a schema path, otherwise it is a col path
	opts *jsonTableColOpts // column options; Next only reads them for columns without nested cols
	cols []*jsonTableCol   // nested columns

	data     []interface{} // values found at path by LoadData (columns with nested cols only)
	err      error         // error from the most recent LoadData path lookup
	pos      int           // index of the current element of data
	finished bool          // exhausted all rows in data (leaf columns set it after producing a value)
	currSib  int           // index in cols of the nested sibling currently being advanced
}
   269  
   270  // IsSibling returns if the jsonTableCol contains multiple columns
   271  func (c *jsonTableCol) IsSibling() bool {
   272  	return len(c.cols) != 0
   273  }
   274  
   275  // NextSibling starts at the current sibling and moves to the next unfinished sibling
   276  // if there are no more unfinished siblings, it sets c.currSib to the first sibling and returns true
   277  // if the c.currSib is unfinished, nothing changes
   278  func (c *jsonTableCol) NextSibling() bool {
   279  	for i := c.currSib; i < len(c.cols); i++ {
   280  		if c.cols[i].IsSibling() && !c.cols[i].finished {
   281  			c.currSib = i
   282  			return false
   283  		}
   284  	}
   285  	c.currSib = 0
   286  	for i := 0; i < len(c.cols); i++ {
   287  		if c.cols[i].IsSibling() {
   288  			c.currSib = i
   289  			break
   290  		}
   291  	}
   292  	return true
   293  }
   294  
   295  // LoadData loads the data for this column from the given object and c.path
   296  // LoadData will always wrap the data in a slice to ensure it is iterable
   297  // Additionally, this function will set the c.currSib to the first sibling
   298  func (c *jsonTableCol) LoadData(obj interface{}) {
   299  	var data interface{}
   300  	data, c.err = jsonpath.JsonPathLookup(obj, c.path)
   301  	if d, ok := data.([]interface{}); ok {
   302  		c.data = d
   303  	} else {
   304  		c.data = []interface{}{data}
   305  	}
   306  	c.pos = 0
   307  
   308  	c.NextSibling()
   309  }
   310  
   311  // Reset clears the column's data and error, and recursively resets all nested columns
   312  func (c *jsonTableCol) Reset() {
   313  	c.data, c.err = nil, nil
   314  	c.finished = false
   315  	for _, col := range c.cols {
   316  		col.Reset()
   317  	}
   318  }
   319  
   320  // Next returns the next row for this column.
   321  func (c *jsonTableCol) Next(obj interface{}, pass bool, ord int) (sql.Row, error) {
   322  	// nested column should recurse
   323  	if len(c.cols) != 0 {
   324  		if c.data == nil {
   325  			c.LoadData(obj)
   326  		}
   327  
   328  		var innerObj interface{}
   329  		if !c.finished {
   330  			innerObj = c.data[c.pos]
   331  		}
   332  
   333  		var row sql.Row
   334  		for i, col := range c.cols {
   335  			innerPass := len(col.cols) != 0 && i != c.currSib
   336  			rowPart, err := col.Next(innerObj, pass || innerPass, c.pos+1)
   337  			if err != nil {
   338  				return nil, err
   339  			}
   340  			row = append(row, rowPart...)
   341  		}
   342  
   343  		if pass {
   344  			return row, nil
   345  		}
   346  
   347  		if c.NextSibling() {
   348  			for _, col := range c.cols {
   349  				col.Reset()
   350  			}
   351  			c.pos++
   352  		}
   353  
   354  		if c.pos >= len(c.data) {
   355  			c.finished = true
   356  		}
   357  
   358  		return row, nil
   359  	}
   360  
   361  	// this should only apply to nested columns, maybe...
   362  	if pass {
   363  		return sql.Row{nil}, nil
   364  	}
   365  
   366  	// FOR ORDINAL is a special case
   367  	if c.opts != nil && c.opts.forOrd {
   368  		return sql.Row{ord}, nil
   369  	}
   370  
   371  	// TODO: cache this?
   372  	val, err := jsonpath.JsonPathLookup(obj, c.path)
   373  	if c.opts.exists {
   374  		if err != nil {
   375  			return sql.Row{0}, nil
   376  		} else {
   377  			return sql.Row{1}, nil
   378  		}
   379  	}
   380  
   381  	// key error means empty
   382  	if err != nil {
   383  		if c.opts.errOnEmp {
   384  			return nil, fmt.Errorf("missing value for JSON_TABLE column '%s'", c.opts.name)
   385  		}
   386  		val = c.opts.defEmpVal
   387  	}
   388  
   389  	val, _, err = c.opts.typ.Convert(val)
   390  	if err != nil {
   391  		if c.opts.errOnErr {
   392  			return nil, err
   393  		}
   394  		val, _, err = c.opts.typ.Convert(c.opts.defErrVal)
   395  		if err != nil {
   396  			return nil, err
   397  		}
   398  	}
   399  
   400  	// Base columns are always finished
   401  	c.finished = true
   402  	return sql.Row{val}, nil
   403  }
   404  
// jsonTableRowIter produces the rows of a JSON_TABLE. For each element of data
// it emits rows, advancing one nested sibling column (currSib) at a time. Once
// NextSibling reports that no nested sibling remains unfinished, all columns
// are reset and the iterator moves on to the next element.
type jsonTableRowIter struct {
	data    []interface{}   // top-level JSON values; one is expanded at a time
	pos     int             // index in data of the value being expanded
	cols    []*jsonTableCol // top-level columns of the table
	currSib int             // index in cols of the nested column currently being advanced
}

var _ sql.RowIter = &jsonTableRowIter{}
   413  
   414  // NextSibling starts at the current sibling and moves to the next unfinished sibling
   415  // if there are no more unfinished siblings, it resets to the first sibling
   416  func (j *jsonTableRowIter) NextSibling() bool {
   417  	for i := j.currSib; i < len(j.cols); i++ {
   418  		if !j.cols[i].finished && len(j.cols[i].cols) != 0 {
   419  			j.currSib = i
   420  			return false
   421  		}
   422  	}
   423  	j.currSib = 0
   424  	for i := 0; i < len(j.cols); i++ {
   425  		if len(j.cols[i].cols) != 0 {
   426  			j.currSib = i
   427  			break
   428  		}
   429  	}
   430  	return true
   431  }
   432  
   433  func (j *jsonTableRowIter) ResetAll() {
   434  	for _, col := range j.cols {
   435  		col.Reset()
   436  	}
   437  }
   438  
   439  func (j *jsonTableRowIter) Next(ctx *sql.Context) (sql.Row, error) {
   440  	if j.pos >= len(j.data) {
   441  		return nil, io.EOF
   442  	}
   443  	obj := j.data[j.pos]
   444  
   445  	var row sql.Row
   446  	for i, col := range j.cols {
   447  		pass := len(col.cols) != 0 && i != j.currSib
   448  		rowPart, err := col.Next(obj, pass, j.pos+1)
   449  		if err != nil {
   450  			return nil, err
   451  		}
   452  		row = append(row, rowPart...)
   453  	}
   454  
   455  	if j.NextSibling() {
   456  		j.ResetAll()
   457  		j.pos++
   458  	}
   459  
   460  	return row, nil
   461  }
   462  
   463  func (j *jsonTableRowIter) Close(ctx *sql.Context) error {
   464  	return nil
   465  }
   466  
   467  // orderedDistinctIter iterates the children iterator and skips all the
   468  // repeated rows assuming the iterator has all rows sorted.
   469  type orderedDistinctIter struct {
   470  	childIter sql.RowIter
   471  	schema    sql.Schema
   472  	prevRow   sql.Row
   473  }
   474  
   475  func newOrderedDistinctIter(child sql.RowIter, schema sql.Schema) *orderedDistinctIter {
   476  	return &orderedDistinctIter{childIter: child, schema: schema}
   477  }
   478  
   479  func (di *orderedDistinctIter) Next(ctx *sql.Context) (sql.Row, error) {
   480  	for {
   481  		row, err := di.childIter.Next(ctx)
   482  		if err != nil {
   483  			return nil, err
   484  		}
   485  
   486  		if di.prevRow != nil {
   487  			ok, err := di.prevRow.Equals(row, di.schema)
   488  			if err != nil {
   489  				return nil, err
   490  			}
   491  
   492  			if ok {
   493  				continue
   494  			}
   495  		}
   496  
   497  		di.prevRow = row
   498  		return row, nil
   499  	}
   500  }
   501  
   502  func (di *orderedDistinctIter) Close(ctx *sql.Context) error {
   503  	return di.childIter.Close(ctx)
   504  }
   505  
   506  type projectIter struct {
   507  	p         []sql.Expression
   508  	childIter sql.RowIter
   509  }
   510  
   511  func (i *projectIter) Next(ctx *sql.Context) (sql.Row, error) {
   512  	childRow, err := i.childIter.Next(ctx)
   513  	if err != nil {
   514  		return nil, err
   515  	}
   516  
   517  	return ProjectRow(ctx, i.p, childRow)
   518  }
   519  
   520  func (i *projectIter) Close(ctx *sql.Context) error {
   521  	return i.childIter.Close(ctx)
   522  }
   523  
   524  // ProjectRow evaluates a set of projections.
   525  func ProjectRow(
   526  	ctx *sql.Context,
   527  	projections []sql.Expression,
   528  	row sql.Row,
   529  ) (sql.Row, error) {
   530  	var secondPass []int
   531  	var fields sql.Row
   532  	for i, expr := range projections {
   533  		// Default values that are expressions may reference other fields, thus they must evaluate after all other exprs.
   534  		// Also default expressions may not refer to other columns that come after them if they also have a default expr.
   535  		// This ensures that all columns referenced by expressions will have already been evaluated.
   536  		// Since literals do not reference other columns, they're evaluated on the first pass.
   537  		defaultVal, isDefaultVal := defaultValFromProjectExpr(expr)
   538  		if isDefaultVal && !defaultVal.IsLiteral() {
   539  			fields = append(fields, nil)
   540  			secondPass = append(secondPass, i)
   541  			continue
   542  		}
   543  		f, fErr := expr.Eval(ctx, row)
   544  		if fErr != nil {
   545  			return nil, fErr
   546  		}
   547  		f = normalizeNegativeZeros(f)
   548  		fields = append(fields, f)
   549  	}
   550  	for _, index := range secondPass {
   551  		field, err := projections[index].Eval(ctx, fields)
   552  		if err != nil {
   553  			return nil, err
   554  		}
   555  		field = normalizeNegativeZeros(field)
   556  		fields[index] = field
   557  	}
   558  	return sql.NewRow(fields...), nil
   559  }
   560  
   561  func defaultValFromProjectExpr(e sql.Expression) (*sql.ColumnDefaultValue, bool) {
   562  	if defaultVal, ok := e.(*expression.Wrapper); ok {
   563  		e = defaultVal.Unwrap()
   564  	}
   565  	if defaultVal, ok := e.(*sql.ColumnDefaultValue); ok {
   566  		return defaultVal, true
   567  	}
   568  
   569  	return nil, false
   570  }
   571  
   572  func defaultValFromSetExpression(e sql.Expression) (*sql.ColumnDefaultValue, bool) {
   573  	if sf, ok := e.(*expression.SetField); ok {
   574  		return defaultValFromProjectExpr(sf.RightChild)
   575  	}
   576  	return nil, false
   577  }
   578  
   579  // normalizeNegativeZeros converts negative zero into positive zero.
   580  // We do this so that floats and decimals have the same representation when displayed to the user.
   581  func normalizeNegativeZeros(val interface{}) interface{} {
   582  	// Golang doesn't have a negative zero literal, but negative zero compares equal to zero.
   583  	if val == float32(0) {
   584  		return float32(0)
   585  	}
   586  	if val == float64(0) {
   587  		return float64(0)
   588  	}
   589  	return val
   590  }
   591  
   592  // TODO a queue is probably more optimal
   593  type recursiveTableIter struct {
   594  	pos int
   595  	buf []sql.Row
   596  }
   597  
   598  var _ sql.RowIter = (*recursiveTableIter)(nil)
   599  
   600  func (r *recursiveTableIter) Next(ctx *sql.Context) (sql.Row, error) {
   601  	if r.buf == nil || r.pos >= len(r.buf) {
   602  		return nil, io.EOF
   603  	}
   604  	r.pos++
   605  	return r.buf[r.pos-1], nil
   606  }
   607  
   608  func (r *recursiveTableIter) Close(ctx *sql.Context) error {
   609  	r.buf = nil
   610  	return nil
   611  }
   612  
   613  func setUserVar(ctx *sql.Context, userVar *expression.UserVar, right sql.Expression, row sql.Row) error {
   614  	val, err := right.Eval(ctx, row)
   615  	if err != nil {
   616  		return err
   617  	}
   618  	typ := types.ApproximateTypeFromValue(val)
   619  
   620  	err = ctx.SetUserVariable(ctx, userVar.Name, val, typ)
   621  	if err != nil {
   622  		return err
   623  	}
   624  	return nil
   625  }
   626  
   627  func setSystemVar(ctx *sql.Context, sysVar *expression.SystemVar, right sql.Expression, row sql.Row) error {
   628  	val, err := right.Eval(ctx, row)
   629  	if err != nil {
   630  		return err
   631  	}
   632  	switch sysVar.Scope {
   633  	case sql.SystemVariableScope_Global:
   634  		err = sql.SystemVariables.SetGlobal(sysVar.Name, val)
   635  		if err != nil {
   636  			return err
   637  		}
   638  	case sql.SystemVariableScope_Session:
   639  		err = ctx.SetSessionVariable(ctx, sysVar.Name, val)
   640  		if err != nil {
   641  			return err
   642  		}
   643  	case sql.SystemVariableScope_Persist:
   644  		persistSess, ok := ctx.Session.(sql.PersistableSession)
   645  		if !ok {
   646  			return sql.ErrSessionDoesNotSupportPersistence.New()
   647  		}
   648  		err = persistSess.PersistGlobal(sysVar.Name, val)
   649  		if err != nil {
   650  			return err
   651  		}
   652  		err = sql.SystemVariables.SetGlobal(sysVar.Name, val)
   653  		if err != nil {
   654  			return err
   655  		}
   656  	case sql.SystemVariableScope_PersistOnly:
   657  		persistSess, ok := ctx.Session.(sql.PersistableSession)
   658  		if !ok {
   659  			return sql.ErrSessionDoesNotSupportPersistence.New()
   660  		}
   661  		err = persistSess.PersistGlobal(sysVar.Name, val)
   662  		if err != nil {
   663  			return err
   664  		}
   665  	case sql.SystemVariableScope_ResetPersist:
   666  		// TODO: add parser support for RESET PERSIST
   667  		persistSess, ok := ctx.Session.(sql.PersistableSession)
   668  		if !ok {
   669  			return sql.ErrSessionDoesNotSupportPersistence.New()
   670  		}
   671  		if sysVar.Name == "" {
   672  			err = persistSess.RemoveAllPersistedGlobals()
   673  		}
   674  		err = persistSess.RemovePersistedGlobal(sysVar.Name)
   675  		if err != nil {
   676  			return err
   677  		}
   678  	default: // should never be hit
   679  		return fmt.Errorf("unable to set `%s` due to unknown scope `%v`", sysVar.Name, sysVar.Scope)
   680  	}
   681  	// Setting `character_set_connection`, regardless of how it is set (directly or through SET NAMES) will also set
   682  	// `collation_connection` to the default collation for the given character set.
   683  	if strings.ToLower(sysVar.Name) == "character_set_connection" {
   684  		newSysVar := &expression.SystemVar{
   685  			Name:  "collation_connection",
   686  			Scope: sysVar.Scope,
   687  		}
   688  		if val == nil {
   689  			err = setSystemVar(ctx, newSysVar, expression.NewLiteral("", types.LongText), row)
   690  			if err != nil {
   691  				return err
   692  			}
   693  		} else {
   694  			valStr, ok := val.(string)
   695  			if !ok {
   696  				return sql.ErrInvalidSystemVariableValue.New("collation_connection", val)
   697  			}
   698  			charset, err := sql.ParseCharacterSet(valStr)
   699  			if err != nil {
   700  				return err
   701  			}
   702  			charset = charset
   703  			err = setSystemVar(ctx, newSysVar, expression.NewLiteral(charset.DefaultCollation().Name(), types.LongText), row)
   704  			if err != nil {
   705  				return err
   706  			}
   707  		}
   708  	}
   709  	return nil
   710  }
   711  
   712  // Applies the update expressions given to the row given, returning the new resultant row.
   713  func applyUpdateExpressions(ctx *sql.Context, updateExprs []sql.Expression, row sql.Row) (sql.Row, error) {
   714  	var ok bool
   715  	prev := row
   716  	for _, updateExpr := range updateExprs {
   717  		val, err := updateExpr.Eval(ctx, prev)
   718  		if err != nil {
   719  			return nil, err
   720  		}
   721  		prev, ok = val.(sql.Row)
   722  		if !ok {
   723  			return nil, plan.ErrUpdateUnexpectedSetResult.New(val)
   724  		}
   725  	}
   726  	return prev, nil
   727  }
   728  
   729  // declareVariablesIter is the sql.RowIter of *DeclareVariables.
   730  type declareVariablesIter struct {
   731  	*plan.DeclareVariables
   732  	row sql.Row
   733  }
   734  
   735  var _ sql.RowIter = (*declareVariablesIter)(nil)
   736  
   737  // Next implements the interface sql.RowIter.
   738  func (d *declareVariablesIter) Next(ctx *sql.Context) (sql.Row, error) {
   739  	defaultVal, err := d.DefaultVal.Eval(ctx, d.row)
   740  	if err != nil {
   741  		return nil, err
   742  	}
   743  	for _, varName := range d.Names {
   744  		if err := d.Pref.InitializeVariable(varName, d.Type, defaultVal); err != nil {
   745  			return nil, err
   746  		}
   747  	}
   748  	return nil, io.EOF
   749  }
   750  
   751  // Close implements the interface sql.RowIter.
   752  func (d *declareVariablesIter) Close(ctx *sql.Context) error {
   753  	return nil
   754  }
   755  
   756  // declareHandlerIter is the sql.RowIter of *DeclareHandler.
   757  type declareHandlerIter struct {
   758  	*plan.DeclareHandler
   759  }
   760  
   761  var _ sql.RowIter = (*declareHandlerIter)(nil)
   762  
   763  // Next implements the interface sql.RowIter.
   764  func (d *declareHandlerIter) Next(ctx *sql.Context) (sql.Row, error) {
   765  	d.Pref.InitializeHandler(d.Statement, d.Action, d.Condition)
   766  	return nil, io.EOF
   767  }
   768  
   769  // Close implements the interface sql.RowIter.
   770  func (d *declareHandlerIter) Close(ctx *sql.Context) error {
   771  	return nil
   772  }
   773  
// cteRecursionLimit is the number of recursive iterations a recursive CTE may
// run. resetIter returns ErrCteRecursionLimitExceeded once the cycle count
// exceeds it.
const cteRecursionLimit = 10001

// recursiveCteIter exhaustively executes a recursive
// relation [rec] populated by an [init] base case.
// Refer to RecursiveCte for more details.
type recursiveCteIter struct {
	// base sql.Project
	init sql.Node
	// recursive sql.Project
	rec sql.Node
	// anchor to recursive table to repopulate with [temp]; may be nil
	working *plan.RecursiveTable
	// true if UNION, false if UNION ALL
	deduplicate bool
	// parent iter initialization state
	row sql.Row

	// active iterator, either [init].RowIter or [rec].RowIter
	iter sql.RowIter
	// number of recursive iterations finished
	cycle int
	// buffer to collect intermediate results for next recursion
	temp []sql.Row
	// duplicate lookup if [deduplicate] set
	cache sql.KeyValueCache
	// builds the [init] and [rec] iterators
	b *BaseBuilder
}

var _ sql.RowIter = (*recursiveCteIter)(nil)
   803  
   804  // Next implements sql.RowIter
   805  func (r *recursiveCteIter) Next(ctx *sql.Context) (sql.Row, error) {
   806  	if r.iter == nil {
   807  		// start with [Init].RowIter
   808  		var err error
   809  		if r.deduplicate {
   810  			r.cache = sql.NewMapCache()
   811  
   812  		}
   813  		r.iter, err = r.b.buildNodeExec(ctx, r.init, r.row)
   814  
   815  		if err != nil {
   816  			return nil, err
   817  		}
   818  	}
   819  
   820  	var row sql.Row
   821  	for {
   822  		var err error
   823  		row, err = r.iter.Next(ctx)
   824  		if errors.Is(err, io.EOF) && len(r.temp) > 0 {
   825  			// reset [Rec].RowIter
   826  			err = r.resetIter(ctx)
   827  			if err != nil {
   828  				return nil, err
   829  			}
   830  			continue
   831  		} else if err != nil {
   832  			return nil, err
   833  		}
   834  
   835  		var key uint64
   836  		if r.deduplicate {
   837  			key, _ = sql.HashOf(row)
   838  			if k, _ := r.cache.Get(key); k != nil {
   839  				// skip duplicate
   840  				continue
   841  			}
   842  		}
   843  		r.store(row, key)
   844  		if err != nil {
   845  			return nil, err
   846  		}
   847  		break
   848  	}
   849  	return row, nil
   850  }
   851  
   852  // store saves a row to the [temp] buffer, and hashes if [deduplicated] = true
   853  func (r *recursiveCteIter) store(row sql.Row, key uint64) {
   854  	if r.deduplicate {
   855  		r.cache.Put(key, struct{}{})
   856  	}
   857  	r.temp = append(r.temp, row)
   858  	return
   859  }
   860  
   861  // resetIter creates a new [Rec].RowIter after refreshing the [working] RecursiveTable
   862  func (r *recursiveCteIter) resetIter(ctx *sql.Context) error {
   863  	if len(r.temp) == 0 {
   864  		return io.EOF
   865  	}
   866  	r.cycle++
   867  	if r.cycle > cteRecursionLimit {
   868  		return sql.ErrCteRecursionLimitExceeded.New()
   869  	}
   870  
   871  	if r.working != nil {
   872  		r.working.Buf = r.temp
   873  		r.temp = make([]sql.Row, 0)
   874  	}
   875  
   876  	err := r.iter.Close(ctx)
   877  	if err != nil {
   878  		return err
   879  	}
   880  	r.iter, err = r.b.buildNodeExec(ctx, r.rec, r.row)
   881  	if err != nil {
   882  		return err
   883  	}
   884  	return nil
   885  }
   886  
   887  // Close implements sql.RowIter
   888  func (r *recursiveCteIter) Close(ctx *sql.Context) error {
   889  	r.working.Buf = nil
   890  	r.temp = nil
   891  	if r.iter != nil {
   892  		return r.iter.Close(ctx)
   893  	}
   894  	return nil
   895  }
   896  
   897  type limitIter struct {
   898  	calcFoundRows bool
   899  	currentPos    int64
   900  	childIter     sql.RowIter
   901  	limit         int64
   902  }
   903  
   904  func (li *limitIter) Next(ctx *sql.Context) (sql.Row, error) {
   905  	if li.currentPos >= li.limit {
   906  		// If we were asked to calc all found rows, then when we are past the limit we iterate over the rest of the
   907  		// result set to count it
   908  		if li.calcFoundRows {
   909  			for {
   910  				_, err := li.childIter.Next(ctx)
   911  				if err != nil {
   912  					return nil, err
   913  				}
   914  				li.currentPos++
   915  			}
   916  		}
   917  
   918  		return nil, io.EOF
   919  	}
   920  
   921  	childRow, err := li.childIter.Next(ctx)
   922  	if err != nil {
   923  		return nil, err
   924  	}
   925  	li.currentPos++
   926  
   927  	return childRow, nil
   928  }
   929  
   930  func (li *limitIter) Close(ctx *sql.Context) error {
   931  	err := li.childIter.Close(ctx)
   932  	if err != nil {
   933  		return err
   934  	}
   935  
   936  	if li.calcFoundRows {
   937  		ctx.SetLastQueryInfo(sql.FoundRows, li.currentPos)
   938  	}
   939  	return nil
   940  }
   941  
   942  type sortIter struct {
   943  	sortFields sql.SortFields
   944  	childIter  sql.RowIter
   945  	sortedRows []sql.Row
   946  	idx        int
   947  }
   948  
   949  var _ sql.RowIter = (*sortIter)(nil)
   950  
   951  func newSortIter(s sql.SortFields, child sql.RowIter) *sortIter {
   952  	return &sortIter{
   953  		sortFields: s,
   954  		childIter:  child,
   955  		idx:        -1,
   956  	}
   957  }
   958  
   959  func (i *sortIter) Next(ctx *sql.Context) (sql.Row, error) {
   960  	if i.idx == -1 {
   961  		err := i.computeSortedRows(ctx)
   962  		if err != nil {
   963  			return nil, err
   964  		}
   965  		i.idx = 0
   966  	}
   967  
   968  	if i.idx >= len(i.sortedRows) {
   969  		return nil, io.EOF
   970  	}
   971  	row := i.sortedRows[i.idx]
   972  	i.idx++
   973  	return row, nil
   974  }
   975  
   976  func (i *sortIter) Close(ctx *sql.Context) error {
   977  	i.sortedRows = nil
   978  	return i.childIter.Close(ctx)
   979  }
   980  
   981  func (i *sortIter) computeSortedRows(ctx *sql.Context) error {
   982  	cache, dispose := ctx.Memory.NewRowsCache()
   983  	defer dispose()
   984  
   985  	for {
   986  		row, err := i.childIter.Next(ctx)
   987  
   988  		if err == io.EOF {
   989  			break
   990  		}
   991  		if err != nil {
   992  			return err
   993  		}
   994  
   995  		if err := cache.Add(row); err != nil {
   996  			return err
   997  		}
   998  	}
   999  
  1000  	rows := cache.Get()
  1001  	sorter := &expression.Sorter{
  1002  		SortFields: i.sortFields,
  1003  		Rows:       rows,
  1004  		LastError:  nil,
  1005  		Ctx:        ctx,
  1006  	}
  1007  	sort.Stable(sorter)
  1008  	if sorter.LastError != nil {
  1009  		return sorter.LastError
  1010  	}
  1011  	i.sortedRows = rows
  1012  	return nil
  1013  }
  1014  
// distinctIter emits the rows of childIter, skipping any row whose hash (as
// computed by sql.HashOf) equals the hash of a row it has already emitted.
// Only hashes are compared, so two different rows that hash to the same value
// are treated as duplicates.
// The hashes are stored in seen, a history cache obtained from ctx.Memory in
// newDistinctIter; dispose releases that cache.
// TODO: come up with a way to use less memory than keeping all hashes in memory.
// Even though they are just 64-bit integers, this could be a problem in large
// result sets.
type distinctIter struct {
	childIter sql.RowIter
	seen      sql.KeyValueCache
	dispose   sql.DisposeFunc
}
  1023  	dispose   sql.DisposeFunc
  1024  }
  1025  
  1026  func newDistinctIter(ctx *sql.Context, child sql.RowIter) *distinctIter {
  1027  	cache, dispose := ctx.Memory.NewHistoryCache()
  1028  	return &distinctIter{
  1029  		childIter: child,
  1030  		seen:      cache,
  1031  		dispose:   dispose,
  1032  	}
  1033  }
  1034  
  1035  func (di *distinctIter) Next(ctx *sql.Context) (sql.Row, error) {
  1036  	for {
  1037  		row, err := di.childIter.Next(ctx)
  1038  		if err != nil {
  1039  			if err == io.EOF {
  1040  				di.Dispose()
  1041  			}
  1042  			return nil, err
  1043  		}
  1044  
  1045  		hash, err := sql.HashOf(row)
  1046  		if err != nil {
  1047  			return nil, err
  1048  		}
  1049  
  1050  		if _, err := di.seen.Get(hash); err == nil {
  1051  			continue
  1052  		}
  1053  
  1054  		if err := di.seen.Put(hash, struct{}{}); err != nil {
  1055  			return nil, err
  1056  		}
  1057  
  1058  		return row, nil
  1059  	}
  1060  }
  1061  
  1062  func (di *distinctIter) Close(ctx *sql.Context) error {
  1063  	di.Dispose()
  1064  	return di.childIter.Close(ctx)
  1065  }
  1066  
  1067  func (di *distinctIter) Dispose() {
  1068  	if di.dispose != nil {
  1069  		di.dispose()
  1070  	}
  1071  }
  1072  
  1073  type unionIter struct {
  1074  	cur      sql.RowIter
  1075  	nextIter func(ctx *sql.Context) (sql.RowIter, error)
  1076  }
  1077  
  1078  func (ui *unionIter) Next(ctx *sql.Context) (sql.Row, error) {
  1079  	res, err := ui.cur.Next(ctx)
  1080  	if err == io.EOF {
  1081  		if ui.nextIter == nil {
  1082  			return nil, io.EOF
  1083  		}
  1084  		err = ui.cur.Close(ctx)
  1085  		if err != nil {
  1086  			return nil, err
  1087  		}
  1088  		ui.cur, err = ui.nextIter(ctx)
  1089  		ui.nextIter = nil
  1090  		if err != nil {
  1091  			return nil, err
  1092  		}
  1093  		return ui.cur.Next(ctx)
  1094  	}
  1095  	return res, err
  1096  }
  1097  
  1098  func (ui *unionIter) Close(ctx *sql.Context) error {
  1099  	if ui.cur != nil {
  1100  		return ui.cur.Close(ctx)
  1101  	} else {
  1102  		return nil
  1103  	}
  1104  }
  1105  
  1106  type intersectIter struct {
  1107  	lIter, rIter sql.RowIter
  1108  	cached       bool
  1109  	cache        map[uint64]int
  1110  }
  1111  
  1112  func (ii *intersectIter) Next(ctx *sql.Context) (sql.Row, error) {
  1113  	if !ii.cached {
  1114  		ii.cache = make(map[uint64]int)
  1115  		for {
  1116  			res, err := ii.rIter.Next(ctx)
  1117  			if err != nil && err != io.EOF {
  1118  				return nil, err
  1119  			}
  1120  
  1121  			hash, herr := sql.HashOf(res)
  1122  			if herr != nil {
  1123  				return nil, herr
  1124  			}
  1125  			if _, ok := ii.cache[hash]; !ok {
  1126  				ii.cache[hash] = 0
  1127  			}
  1128  			ii.cache[hash]++
  1129  
  1130  			if err == io.EOF {
  1131  				break
  1132  			}
  1133  		}
  1134  		ii.cached = true
  1135  	}
  1136  
  1137  	for {
  1138  		res, err := ii.lIter.Next(ctx)
  1139  		if err != nil {
  1140  			return nil, err
  1141  		}
  1142  
  1143  		hash, herr := sql.HashOf(res)
  1144  		if herr != nil {
  1145  			return nil, herr
  1146  		}
  1147  		if _, ok := ii.cache[hash]; !ok {
  1148  			continue
  1149  		}
  1150  		if ii.cache[hash] <= 0 {
  1151  			continue
  1152  		}
  1153  		ii.cache[hash]--
  1154  
  1155  		return res, nil
  1156  	}
  1157  }
  1158  
  1159  func (ii *intersectIter) Close(ctx *sql.Context) error {
  1160  	if ii.lIter != nil {
  1161  		if err := ii.lIter.Close(ctx); err != nil {
  1162  			return err
  1163  		}
  1164  	}
  1165  	if ii.rIter != nil {
  1166  		if err := ii.rIter.Close(ctx); err != nil {
  1167  			return err
  1168  		}
  1169  	}
  1170  	return nil
  1171  }
  1172  
  1173  type exceptIter struct {
  1174  	lIter, rIter sql.RowIter
  1175  	cached       bool
  1176  	cache        map[uint64]int
  1177  }
  1178  
  1179  func (ei *exceptIter) Next(ctx *sql.Context) (sql.Row, error) {
  1180  	if !ei.cached {
  1181  		ei.cache = make(map[uint64]int)
  1182  		for {
  1183  			res, err := ei.rIter.Next(ctx)
  1184  			if err != nil && err != io.EOF {
  1185  				return nil, err
  1186  			}
  1187  
  1188  			hash, herr := sql.HashOf(res)
  1189  			if herr != nil {
  1190  				return nil, herr
  1191  			}
  1192  			if _, ok := ei.cache[hash]; !ok {
  1193  				ei.cache[hash] = 0
  1194  			}
  1195  			ei.cache[hash]++
  1196  
  1197  			if err == io.EOF {
  1198  				break
  1199  			}
  1200  		}
  1201  		ei.cached = true
  1202  	}
  1203  
  1204  	for {
  1205  		res, err := ei.lIter.Next(ctx)
  1206  		if err != nil {
  1207  			return nil, err
  1208  		}
  1209  
  1210  		hash, herr := sql.HashOf(res)
  1211  		if herr != nil {
  1212  			return nil, herr
  1213  		}
  1214  		if _, ok := ei.cache[hash]; !ok {
  1215  			return res, nil
  1216  		}
  1217  		if ei.cache[hash] <= 0 {
  1218  			return res, nil
  1219  		}
  1220  		ei.cache[hash]--
  1221  	}
  1222  }
  1223  
  1224  func (ei *exceptIter) Close(ctx *sql.Context) error {
  1225  	if ei.lIter != nil {
  1226  		if err := ei.lIter.Close(ctx); err != nil {
  1227  			return err
  1228  		}
  1229  	}
  1230  	if ei.rIter != nil {
  1231  		if err := ei.rIter.Close(ctx); err != nil {
  1232  			return err
  1233  		}
  1234  	}
  1235  	return nil
  1236  }