github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/aggregates.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package planbuilder
    16  
    17  import (
    18  	"fmt"
    19  	"sort"
    20  	"strings"
    21  
    22  	ast "github.com/dolthub/vitess/go/vt/sqlparser"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/expression"
    26  	"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation"
    27  	"github.com/dolthub/go-mysql-server/sql/plan"
    28  	"github.com/dolthub/go-mysql-server/sql/transform"
    29  	"github.com/dolthub/go-mysql-server/sql/types"
    30  )
    31  
// Compile-time assertion that *aggregateInfo satisfies the parser's ast.Expr
// interface (aggregateInfo embeds an ast.Expr).
var _ ast.Expr = (*aggregateInfo)(nil)
    33  
// groupBy accumulates the state needed to plan an aggregation for a scope:
// the input columns it depends on, the scope that receives aggregate output
// columns, and an index of already-built aggregates.
type groupBy struct {
	// inCols are input columns referenced by GROUP BY terms and by
	// aggregate arguments (see addInCol callers).
	inCols   []scopeColumn
	// outScope receives a new column for every aggregate that is built.
	outScope *scope
	// aggs maps the lowercased String() of an aggregate's scalar to its
	// output column so an identical aggregate can be referenced, not rebuilt.
	aggs     map[string]scopeColumn
	// NOTE(review): grouping is not read or written in this file; confirm
	// its purpose against the rest of the package.
	grouping map[string]bool
}
    40  
    41  func (g *groupBy) addInCol(c scopeColumn) {
    42  	g.inCols = append(g.inCols, c)
    43  }
    44  
    45  func (g *groupBy) addOutCol(c scopeColumn) columnId {
    46  	return g.outScope.newColumn(c)
    47  }
    48  
    49  func (g *groupBy) hasAggs() bool {
    50  	return len(g.aggs) > 0
    51  }
    52  
    53  func (g *groupBy) aggregations() []scopeColumn {
    54  	aggregations := make([]scopeColumn, 0, len(g.aggs))
    55  	for _, agg := range g.aggs {
    56  		aggregations = append(aggregations, agg)
    57  	}
    58  	sort.Slice(aggregations, func(i, j int) bool {
    59  		return aggregations[i].scalar.String() < aggregations[j].scalar.String()
    60  	})
    61  	return aggregations
    62  }
    63  
    64  func (g *groupBy) addAggStr(c scopeColumn) {
    65  	if g.aggs == nil {
    66  		g.aggs = make(map[string]scopeColumn)
    67  	}
    68  	g.aggs[strings.ToLower(c.scalar.String())] = c
    69  }
    70  
    71  func (g *groupBy) getAggRef(name string) sql.Expression {
    72  	if g.aggs == nil {
    73  		return nil
    74  	}
    75  	ret, _ := g.aggs[name]
    76  	if ret.empty() {
    77  		return nil
    78  	}
    79  	return ret.scalarGf()
    80  }
    81  
// aggregateInfo wraps a parser expression. Because it embeds ast.Expr, it
// satisfies that interface itself.
type aggregateInfo struct {
	ast.Expr
}
    85  
    86  func (b *Builder) needsAggregation(fromScope *scope, sel *ast.Select) bool {
    87  	return len(sel.GroupBy) > 0 ||
    88  		(fromScope.groupBy != nil && fromScope.groupBy.hasAggs())
    89  }
    90  
    91  func (b *Builder) buildGroupingCols(fromScope, projScope *scope, groupby ast.GroupBy, selects ast.SelectExprs) []sql.Expression {
    92  	// grouping col will either be:
    93  	// 1) alias into targets
    94  	// 2) a column reference
    95  	// 3) an index into selects
    96  	// 4) a simple non-aggregate expression
    97  	groupings := make([]sql.Expression, 0)
    98  	if fromScope.groupBy == nil {
    99  		fromScope.initGroupBy()
   100  	}
   101  	g := fromScope.groupBy
   102  	for _, e := range groupby {
   103  		var col scopeColumn
   104  		switch e := e.(type) {
   105  		case *ast.ColName:
   106  			var ok bool
   107  			// GROUP BY binds to column references before projections.
   108  			dbName := strings.ToLower(e.Qualifier.Qualifier.String())
   109  			tblName := strings.ToLower(e.Qualifier.Name.String())
   110  			colName := strings.ToLower(e.Name.String())
   111  			col, ok = fromScope.resolveColumn(dbName, tblName, colName, true, false)
   112  			if !ok {
   113  				col, ok = projScope.resolveColumn(dbName, tblName, colName, true, true)
   114  			}
   115  
   116  			if !ok {
   117  				b.handleErr(sql.ErrColumnNotFound.New(e.Name.String()))
   118  			}
   119  		case *ast.SQLVal:
   120  			// literal -> index into targets
   121  			replace := b.normalizeValArg(e)
   122  			val, ok := replace.(*ast.SQLVal)
   123  			if !ok {
   124  				// ast.NullVal
   125  				continue
   126  			}
   127  			if val.Type == ast.IntVal {
   128  				lit := b.convertInt(string(val.Val), 10)
   129  				idx, _, err := types.Int64.Convert(lit.Value())
   130  				if err != nil {
   131  					b.handleErr(err)
   132  				}
   133  				intIdx, ok := idx.(int64)
   134  				if !ok {
   135  					b.handleErr(fmt.Errorf("expected integer order by literal"))
   136  				}
   137  				if intIdx < 1 {
   138  					b.handleErr(fmt.Errorf("expected positive integer order by literal"))
   139  				}
   140  				col = projScope.cols[intIdx-1]
   141  			}
   142  		default:
   143  			expr := b.buildScalar(fromScope, e)
   144  			col = scopeColumn{
   145  				col:      expr.String(),
   146  				typ:      nil,
   147  				scalar:   expr,
   148  				nullable: expr.IsNullable(),
   149  			}
   150  		}
   151  		if col.scalar == nil {
   152  			gf := expression.NewGetFieldWithTable(int(col.id), int(col.tableId), col.typ, col.db, col.table, col.col, col.nullable)
   153  			id, ok := fromScope.getExpr(gf.String(), true)
   154  			if !ok {
   155  				err := sql.ErrColumnNotFound.New(gf.String())
   156  				b.handleErr(err)
   157  			}
   158  			col.scalar = gf.WithIndex(int(id))
   159  		}
   160  		g.addInCol(col)
   161  		groupings = append(groupings, col.scalar)
   162  	}
   163  
   164  	return groupings
   165  }
   166  
   167  func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []sql.Expression) *scope {
   168  	// GROUP_BY consists of:
   169  	// - input arguments projection
   170  	// - grouping cols projection
   171  	// - aggregate expressions
   172  	// - output projection
   173  	if fromScope.groupBy == nil {
   174  		fromScope.initGroupBy()
   175  	}
   176  
   177  	group := fromScope.groupBy
   178  	outScope := group.outScope
   179  	// select columns:
   180  	//  - aggs
   181  	//  - extra columns needed by having, order by, select
   182  	var selectExprs []sql.Expression
   183  	var selectGfs []sql.Expression
   184  	selectStr := make(map[string]bool)
   185  	for _, e := range group.aggregations() {
   186  		if !selectStr[strings.ToLower(e.String())] {
   187  			selectExprs = append(selectExprs, e.scalar)
   188  			selectGfs = append(selectGfs, e.scalarGf())
   189  			selectStr[strings.ToLower(e.String())] = true
   190  		}
   191  	}
   192  	var aliases []sql.Expression
   193  	for _, col := range projScope.cols {
   194  		// eval aliases in project scope
   195  		switch e := col.scalar.(type) {
   196  		case *expression.Alias:
   197  			if !e.Unreferencable() {
   198  				aliases = append(aliases, e.WithId(sql.ColumnId(col.id)).(*expression.Alias))
   199  			}
   200  		default:
   201  		}
   202  
   203  		// projection dependencies -> table cols needed above
   204  		transform.InspectExpr(col.scalar, func(e sql.Expression) bool {
   205  			switch e := e.(type) {
   206  			case *expression.GetField:
   207  				colName := strings.ToLower(e.String())
   208  				if !selectStr[colName] {
   209  					selectExprs = append(selectExprs, e)
   210  					selectGfs = append(selectGfs, e)
   211  					selectStr[colName] = true
   212  				}
   213  			default:
   214  			}
   215  			return false
   216  		})
   217  	}
   218  	for _, e := range fromScope.extraCols {
   219  		// accessory cols used by ORDER_BY, HAVING
   220  		if !selectStr[e.String()] {
   221  			selectExprs = append(selectExprs, e.scalarGf())
   222  			selectGfs = append(selectGfs, e.scalarGf())
   223  
   224  			selectStr[e.String()] = true
   225  		}
   226  	}
   227  	gb := plan.NewGroupBy(selectExprs, groupingCols, fromScope.node)
   228  	outScope.node = gb
   229  
   230  	if len(aliases) > 0 {
   231  		outScope.node = plan.NewProject(append(selectGfs, aliases...), outScope.node)
   232  	}
   233  	return outScope
   234  }
   235  
   236  func isAggregateFunc(name string) bool {
   237  	switch name {
   238  	case "avg", "bit_and", "bit_or", "bit_xor", "count",
   239  		"group_concat", "json_arrayagg", "json_objectagg",
   240  		"max", "min", "std", "stddev_pop", "stddev_samp",
   241  		"stddev", "sum", "var_pop", "var_samp", "variance",
   242  		"first", "last", "any_value":
   243  		return true
   244  	default:
   245  		return false
   246  	}
   247  }
   248  
   249  // buildAggregateFunc tags aggregate functions in the correct scope
   250  // and makes the aggregate available for reference by other clauses.
   251  func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExpr) sql.Expression {
   252  	if len(inScope.windowFuncs) > 0 {
   253  		err := sql.ErrNonAggregatedColumnWithoutGroupBy.New()
   254  		b.handleErr(err)
   255  	}
   256  
   257  	if inScope.groupBy == nil {
   258  		inScope.initGroupBy()
   259  	}
   260  	gb := inScope.groupBy
   261  
   262  	if name == "count" {
   263  		if _, ok := e.Exprs[0].(*ast.StarExpr); ok {
   264  			var agg sql.Aggregation
   265  			if e.Distinct {
   266  				agg = aggregation.NewCountDistinct(expression.NewLiteral(1, types.Int64))
   267  			} else {
   268  				agg = aggregation.NewCount(expression.NewLiteral(1, types.Int64))
   269  			}
   270  			aggName := strings.ToLower(agg.String())
   271  			gf := gb.getAggRef(aggName)
   272  			if gf != nil {
   273  				// if we've already computed use reference here
   274  				return gf
   275  			}
   276  
   277  			col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
   278  			id := gb.outScope.newColumn(col)
   279  			col.id = id
   280  
   281  			agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation)
   282  			gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
   283  			col.scalar = agg
   284  
   285  			gb.addAggStr(col)
   286  			return col.scalarGf()
   287  		}
   288  	}
   289  
   290  	if name == "jsonarray" {
   291  		// TODO we don't have any tests for this
   292  		if _, ok := e.Exprs[0].(*ast.StarExpr); ok {
   293  			var agg sql.Aggregation
   294  			agg = aggregation.NewJsonArray(expression.NewLiteral(expression.NewStar(), types.Int64))
   295  			//if e.Distinct {
   296  			//	agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
   297  			//}
   298  			aggName := strings.ToLower(agg.String())
   299  			gf := gb.getAggRef(aggName)
   300  			if gf != nil {
   301  				// if we've already computed use reference here
   302  				return gf
   303  			}
   304  
   305  			col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
   306  			id := gb.outScope.newColumn(col)
   307  
   308  			agg = agg.WithId(sql.ColumnId(id)).(*aggregation.JsonArray)
   309  			gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
   310  			col.scalar = agg
   311  
   312  			col.id = id
   313  			gb.addAggStr(col)
   314  			return col.scalarGf()
   315  		}
   316  	}
   317  
   318  	var args []sql.Expression
   319  	for _, arg := range e.Exprs {
   320  		e := b.selectExprToExpression(inScope, arg)
   321  		switch e := e.(type) {
   322  		case *expression.GetField:
   323  			if e.TableId() == 0 {
   324  				// TODO: not sure where this came from but it's not true
   325  				// aliases are not valid aggregate arguments, the alias must be masking a column
   326  				gf := b.selectExprToExpression(inScope.parent, arg)
   327  				var ok bool
   328  				e, ok = gf.(*expression.GetField)
   329  				if !ok || e.TableId() == 0 {
   330  					b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", gf))
   331  				}
   332  			}
   333  			args = append(args, e)
   334  			col := scopeColumn{tableId: e.TableID(), db: e.Database(), table: e.Table(), col: e.Name(), scalar: e, typ: e.Type(), nullable: e.IsNullable()}
   335  			gb.addInCol(col)
   336  		case *expression.Star:
   337  			err := sql.ErrStarUnsupported.New()
   338  			b.handleErr(err)
   339  		case *plan.Subquery:
   340  			args = append(args, e)
   341  			col := scopeColumn{col: e.QueryString, scalar: e, typ: e.Type()}
   342  			gb.addInCol(col)
   343  		default:
   344  			args = append(args, e)
   345  			col := scopeColumn{col: e.String(), scalar: e, typ: e.Type()}
   346  			gb.addInCol(col)
   347  		}
   348  	}
   349  
   350  	var agg sql.Aggregation
   351  	if e.Distinct && name == "count" {
   352  		agg = aggregation.NewCountDistinct(args...)
   353  	} else {
   354  
   355  		// NOTE: Not all aggregate functions support DISTINCT. Fortunately, the vitess parser will throw
   356  		// errors for when DISTINCT is used on aggregate functions that don't support DISTINCT.
   357  		if e.Distinct {
   358  			if len(e.Exprs) != 1 {
   359  				err := sql.ErrUnsupportedSyntax.New("more than one expression with distinct")
   360  				b.handleErr(err)
   361  			}
   362  
   363  			args[0] = expression.NewDistinctExpression(args[0])
   364  		}
   365  
   366  		f, err := b.cat.Function(b.ctx, name)
   367  		if err != nil {
   368  			b.handleErr(err)
   369  		}
   370  
   371  		newInst, err := f.NewInstance(args)
   372  		if err != nil {
   373  			b.handleErr(err)
   374  		}
   375  		var ok bool
   376  		agg, ok = newInst.(sql.Aggregation)
   377  		if !ok {
   378  			err := fmt.Errorf("expected function to be aggregation: %s", f.FunctionName())
   379  			b.handleErr(err)
   380  		}
   381  	}
   382  
   383  	aggType := agg.Type()
   384  	if name == "avg" || name == "sum" {
   385  		aggType = types.Float64
   386  	}
   387  
   388  	aggName := strings.ToLower(plan.AliasSubqueryString(agg))
   389  	if id, ok := gb.outScope.getExpr(aggName, true); ok {
   390  		// if we've already computed use reference here
   391  		gf := expression.NewGetFieldWithTable(int(id), 0, aggType, "", "", aggName, agg.IsNullable())
   392  		return gf
   393  	}
   394  
   395  	col := scopeColumn{col: aggName, scalar: agg, typ: aggType, nullable: agg.IsNullable()}
   396  	id := gb.outScope.newColumn(col)
   397  
   398  	agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation)
   399  	gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
   400  	col.scalar = agg
   401  
   402  	col.id = id
   403  	gb.addAggStr(col)
   404  	return col.scalarGf()
   405  }
   406  
   407  func (b *Builder) buildGroupConcat(inScope *scope, e *ast.GroupConcatExpr) sql.Expression {
   408  	if inScope.groupBy == nil {
   409  		inScope.initGroupBy()
   410  	}
   411  	gb := inScope.groupBy
   412  
   413  	args := make([]sql.Expression, len(e.Exprs))
   414  	for i, a := range e.Exprs {
   415  		args[i] = b.selectExprToExpression(inScope, a)
   416  	}
   417  
   418  	separatorS := ","
   419  	if !e.Separator.DefaultSeparator {
   420  		separatorS = e.Separator.SeparatorString
   421  	}
   422  
   423  	orderByScope := b.analyzeOrderBy(inScope, inScope, e.OrderBy)
   424  	var sortFields sql.SortFields
   425  	for _, c := range orderByScope.cols {
   426  		so := sql.Ascending
   427  		if c.descending {
   428  			so = sql.Descending
   429  		}
   430  		scalar := c.scalar
   431  		if scalar == nil {
   432  			scalar = c.scalarGf()
   433  		}
   434  		sf := sql.SortField{
   435  			Column: scalar,
   436  			Order:  so,
   437  		}
   438  		sortFields = append(sortFields, sf)
   439  	}
   440  
   441  	//TODO: this should be acquired at runtime, not at parse time, so fix this
   442  	gcml, err := b.ctx.GetSessionVariable(b.ctx, "group_concat_max_len")
   443  	if err != nil {
   444  		b.handleErr(err)
   445  	}
   446  	groupConcatMaxLen := gcml.(uint64)
   447  
   448  	// todo store ref to aggregate
   449  	agg := aggregation.NewGroupConcat(e.Distinct, sortFields, separatorS, args, int(groupConcatMaxLen))
   450  	aggName := strings.ToLower(plan.AliasSubqueryString(agg))
   451  	col := scopeColumn{col: aggName, scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
   452  
   453  	id := gb.outScope.newColumn(col)
   454  
   455  	agg = agg.WithId(sql.ColumnId(id)).(*aggregation.GroupConcat)
   456  	gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
   457  	col.scalar = agg
   458  
   459  	gb.addAggStr(col)
   460  	col.id = id
   461  	return col.scalarGf()
   462  }
   463  
   464  func isWindowFunc(name string) bool {
   465  	switch name {
   466  	case "first", "last", "count", "sum", "any_value",
   467  		"avg", "max", "min", "count_distinct", "json_arrayagg",
   468  		"row_number", "percent_rank", "lead", "lag",
   469  		"first_value", "last_value",
   470  		"rank", "dense_rank":
   471  		return true
   472  	default:
   473  		return false
   474  	}
   475  }
   476  
   477  func (b *Builder) buildWindowFunc(inScope *scope, name string, e *ast.FuncExpr, over *ast.WindowDef) sql.Expression {
   478  	if inScope.groupBy != nil {
   479  		err := sql.ErrNonAggregatedColumnWithoutGroupBy.New()
   480  		b.handleErr(err)
   481  	}
   482  
   483  	// internal expressions can be complex, but window can't be more than alias
   484  	var args []sql.Expression
   485  	for _, arg := range e.Exprs {
   486  		e := b.selectExprToExpression(inScope, arg)
   487  		args = append(args, e)
   488  	}
   489  
   490  	var win sql.WindowAdaptableExpression
   491  	if name == "count" {
   492  		if _, ok := e.Exprs[0].(*ast.StarExpr); ok {
   493  			win = aggregation.NewCount(expression.NewLiteral(1, types.Int64))
   494  		}
   495  	}
   496  	if win == nil {
   497  		f, err := b.cat.Function(b.ctx, name)
   498  		if err != nil {
   499  			b.handleErr(err)
   500  		}
   501  
   502  		newInst, err := f.NewInstance(args)
   503  		var ok bool
   504  		win, ok = newInst.(sql.WindowAdaptableExpression)
   505  		if !ok {
   506  			err := fmt.Errorf("function is not a window adaptable exprssion: %s", f.FunctionName())
   507  			b.handleErr(err)
   508  		}
   509  		if err != nil {
   510  			b.handleErr(err)
   511  		}
   512  	}
   513  
   514  	def := b.buildWindowDef(inScope, over)
   515  	switch w := win.(type) {
   516  	case sql.WindowAdaptableExpression:
   517  		win = w.WithWindow(def)
   518  	}
   519  
   520  	col := scopeColumn{col: strings.ToLower(win.String()), scalar: win, typ: win.Type(), nullable: win.IsNullable()}
   521  	id := inScope.newColumn(col)
   522  	col.id = id
   523  	win = win.WithId(sql.ColumnId(id)).(sql.WindowAdaptableExpression)
   524  	inScope.cols[len(inScope.cols)-1].scalar = win
   525  	col.scalar = win
   526  	inScope.windowFuncs = append(inScope.windowFuncs, col)
   527  	return col.scalarGf()
   528  }
   529  
   530  func (b *Builder) buildWindow(fromScope, projScope *scope) *scope {
   531  	if len(fromScope.windowFuncs) == 0 {
   532  		return fromScope
   533  	}
   534  	// passthrough dependency cols plus window funcs
   535  	var selectExprs []sql.Expression
   536  	var selectGfs []sql.Expression
   537  	selectStr := make(map[string]bool)
   538  	for _, col := range fromScope.windowFuncs {
   539  		e := col.scalar
   540  		if !selectStr[strings.ToLower(e.String())] {
   541  			switch e.(type) {
   542  			case sql.WindowAdaptableExpression:
   543  				selectStr[strings.ToLower(e.String())] = true
   544  				selectExprs = append(selectExprs, e)
   545  				selectGfs = append(selectGfs, col.scalarGf())
   546  			default:
   547  				err := fmt.Errorf("expected window function to be sql.WindowAggregation")
   548  				b.handleErr(err)
   549  			}
   550  		}
   551  	}
   552  	var aliases []sql.Expression
   553  	for _, col := range projScope.cols {
   554  		// eval aliases in project scope
   555  		switch e := col.scalar.(type) {
   556  		case *expression.Alias:
   557  			if !e.Unreferencable() {
   558  				aliases = append(aliases, e.WithId(sql.ColumnId(col.id)).(*expression.Alias))
   559  			}
   560  		default:
   561  		}
   562  
   563  		// projection dependencies -> table cols needed above
   564  		transform.InspectExpr(col.scalar, func(e sql.Expression) bool {
   565  			switch e := e.(type) {
   566  			case *expression.GetField:
   567  				colName := strings.ToLower(e.String())
   568  				if !selectStr[colName] {
   569  					selectExprs = append(selectExprs, e)
   570  					selectGfs = append(selectGfs, e)
   571  					selectStr[colName] = true
   572  				}
   573  			default:
   574  			}
   575  			return false
   576  		})
   577  	}
   578  	for _, e := range fromScope.extraCols {
   579  		// accessory cols used by ORDER_BY, HAVING
   580  		if !selectStr[e.String()] {
   581  			selectExprs = append(selectExprs, e.scalarGf())
   582  			selectGfs = append(selectGfs, e.scalarGf())
   583  			selectStr[e.String()] = true
   584  		}
   585  	}
   586  
   587  	outScope := fromScope
   588  	window := plan.NewWindow(selectExprs, fromScope.node)
   589  	fromScope.node = window
   590  
   591  	if len(aliases) > 0 {
   592  		outScope.node = plan.NewProject(append(selectGfs, aliases...), outScope.node)
   593  	}
   594  
   595  	return outScope
   596  }
   597  
   598  func (b *Builder) buildNamedWindows(fromScope *scope, window ast.Window) {
   599  	// topo sort first
   600  	adj := make(map[string]*ast.WindowDef)
   601  	for _, w := range window {
   602  		adj[w.Name.Lowered()] = w
   603  	}
   604  
   605  	var topo []*ast.WindowDef
   606  	var seen map[string]bool
   607  	var dfs func(string)
   608  	dfs = func(name string) {
   609  		if ok, _ := seen[name]; ok {
   610  			b.handleErr(sql.ErrCircularWindowInheritance.New())
   611  		}
   612  		seen[name] = true
   613  		cur := adj[name]
   614  		if ref := cur.NameRef.Lowered(); ref != "" {
   615  			dfs(ref)
   616  		}
   617  		topo = append(topo, cur)
   618  	}
   619  	for _, w := range adj {
   620  		seen = make(map[string]bool)
   621  		dfs(w.Name.Lowered())
   622  	}
   623  
   624  	fromScope.windowDefs = make(map[string]*sql.WindowDefinition)
   625  	for _, w := range topo {
   626  		fromScope.windowDefs[w.Name.Lowered()] = b.buildWindowDef(fromScope, w)
   627  	}
   628  	return
   629  }
   630  
   631  func (b *Builder) buildWindowDef(fromScope *scope, def *ast.WindowDef) *sql.WindowDefinition {
   632  	if def == nil {
   633  		return nil
   634  	}
   635  
   636  	var sortFields sql.SortFields
   637  	for _, c := range def.OrderBy {
   638  		// resolve col in fromScope
   639  		e := b.buildScalar(fromScope, c.Expr)
   640  		so := sql.Ascending
   641  		if c.Direction == ast.DescScr {
   642  			so = sql.Descending
   643  		}
   644  		sf := sql.SortField{
   645  			Column: e,
   646  			Order:  so,
   647  		}
   648  		sortFields = append(sortFields, sf)
   649  	}
   650  
   651  	partitions := make([]sql.Expression, len(def.PartitionBy))
   652  	for i, expr := range def.PartitionBy {
   653  		partitions[i] = b.buildScalar(fromScope, expr)
   654  	}
   655  
   656  	frame := b.NewFrame(fromScope, def.Frame)
   657  
   658  	// According to MySQL documentation at https://dev.mysql.com/doc/refman/8.0/en/window-functions-usage.html
   659  	// "If OVER() is empty, the window consists of all query rows and the window function computes a result using all rows."
   660  	if def.OrderBy == nil && frame == nil {
   661  		frame = plan.NewRowsUnboundedPrecedingToUnboundedFollowingFrame()
   662  	}
   663  
   664  	windowDef := sql.NewWindowDefinition(partitions, sortFields, frame, def.NameRef.Lowered(), def.Name.Lowered())
   665  	if ref, ok := fromScope.windowDefs[def.NameRef.Lowered()]; ok {
   666  		// this is only safe if windows are built in topo order
   667  		windowDef = b.mergeWindowDefs(windowDef, ref)
   668  		// collapse dependencies if any reference this window
   669  		fromScope.windowDefs[windowDef.Name] = windowDef
   670  	}
   671  	return windowDef
   672  }
   673  
   674  // mergeWindowDefs combines the attributes of two window definitions or returns
   675  // an error if the two are incompatible. [def] should have a reference to
   676  // [ref] through [def.Ref], and the return value drops the reference to indicate
   677  // the two were properly combined.
   678  func (b *Builder) mergeWindowDefs(def, ref *sql.WindowDefinition) *sql.WindowDefinition {
   679  	if ref.Ref != "" {
   680  		panic("unreachable; cannot merge unresolved window definition")
   681  	}
   682  
   683  	var orderBy sql.SortFields
   684  	switch {
   685  	case len(def.OrderBy) > 0 && len(ref.OrderBy) > 0:
   686  		err := sql.ErrInvalidWindowInheritance.New("", "", "both contain order by clause")
   687  		b.handleErr(err)
   688  	case len(def.OrderBy) > 0:
   689  		orderBy = def.OrderBy
   690  	case len(ref.OrderBy) > 0:
   691  		orderBy = ref.OrderBy
   692  	default:
   693  	}
   694  
   695  	var partitionBy []sql.Expression
   696  	switch {
   697  	case len(def.PartitionBy) > 0 && len(ref.PartitionBy) > 0:
   698  		err := sql.ErrInvalidWindowInheritance.New("", "", "both contain partition by clause")
   699  		b.handleErr(err)
   700  	case len(def.PartitionBy) > 0:
   701  		partitionBy = def.PartitionBy
   702  	case len(ref.PartitionBy) > 0:
   703  		partitionBy = ref.PartitionBy
   704  	default:
   705  		partitionBy = []sql.Expression{}
   706  	}
   707  
   708  	var frame sql.WindowFrame
   709  	switch {
   710  	case def.Frame != nil && ref.Frame != nil:
   711  		_, isDefDefaultFrame := def.Frame.(*plan.RowsUnboundedPrecedingToUnboundedFollowingFrame)
   712  		_, isRefDefaultFrame := ref.Frame.(*plan.RowsUnboundedPrecedingToUnboundedFollowingFrame)
   713  
   714  		// if both frames are set and one is RowsUnboundedPrecedingToUnboundedFollowingFrame (default),
   715  		// we should use the other frame
   716  		if isDefDefaultFrame {
   717  			frame = ref.Frame
   718  		} else if isRefDefaultFrame {
   719  			frame = def.Frame
   720  		} else {
   721  			// if both frames have identical string representations, use either one
   722  			df := def.Frame.String()
   723  			rf := ref.Frame.String()
   724  			if df != rf {
   725  				err := sql.ErrInvalidWindowInheritance.New("", "", "both contain different frame clauses")
   726  				b.handleErr(err)
   727  			}
   728  			frame = def.Frame
   729  		}
   730  	case def.Frame != nil:
   731  		frame = def.Frame
   732  	case ref.Frame != nil:
   733  		frame = ref.Frame
   734  	default:
   735  	}
   736  
   737  	return sql.NewWindowDefinition(partitionBy, orderBy, frame, "", def.Name)
   738  }
   739  
   740  func (b *Builder) analyzeHaving(fromScope, projScope *scope, having *ast.Where) {
   741  	// build having filter expr
   742  	// aggregates added to fromScope.groupBy
   743  	// can see projScope outputs
   744  	if having == nil {
   745  		return
   746  	}
   747  
   748  	ast.Walk(func(node ast.SQLNode) (bool, error) {
   749  		switch n := node.(type) {
   750  		case *ast.Subquery:
   751  			return false, nil
   752  		case *ast.FuncExpr:
   753  			name := n.Name.Lowered()
   754  			if isAggregateFunc(name) {
   755  				// record aggregate
   756  				// TODO: this should get projScope as well
   757  				_ = b.buildAggregateFunc(fromScope, name, n)
   758  			} else if isWindowFunc(name) {
   759  				_ = b.buildWindowFunc(fromScope, name, n, (*ast.WindowDef)(n.Over))
   760  			}
   761  		case *ast.ColName:
   762  			// add to extra cols
   763  			dbName := strings.ToLower(n.Qualifier.Qualifier.String())
   764  			tblName := strings.ToLower(n.Qualifier.Name.String())
   765  			colName := strings.ToLower(n.Name.String())
   766  			c, ok := fromScope.resolveColumn(dbName, tblName, colName, true, false)
   767  			if ok {
   768  				c.scalar = expression.NewGetFieldWithTable(int(c.id), 0, c.typ, c.db, c.table, c.col, c.nullable)
   769  				fromScope.addExtraColumn(c)
   770  				break
   771  			}
   772  			c, ok = projScope.resolveColumn(dbName, tblName, colName, false, true)
   773  			if ok {
   774  				// references projection alias
   775  				break
   776  			}
   777  			err := sql.ErrColumnNotFound.New(n.Name)
   778  			b.handleErr(err)
   779  		}
   780  		return true, nil
   781  	}, having.Expr)
   782  }
   783  
   784  func (b *Builder) buildInnerProj(fromScope, projScope *scope) *scope {
   785  	outScope := fromScope
   786  	var proj []sql.Expression
   787  
   788  	// eval aliases in project scope
   789  	for _, col := range projScope.cols {
   790  		switch e := col.scalar.(type) {
   791  		case *expression.Alias:
   792  			if !e.Unreferencable() {
   793  				proj = append(proj, e.WithId(sql.ColumnId(col.id)).(*expression.Alias))
   794  			}
   795  		}
   796  	}
   797  
   798  	aliasCnt := len(proj)
   799  
   800  	if len(proj) == 0 && !(len(fromScope.cols) == 1 && fromScope.cols[0].id == 0) {
   801  		// remove redundant projection unless it is the single dual table column
   802  		return outScope
   803  	}
   804  
   805  	for _, c := range fromScope.cols {
   806  		proj = append(proj, c.scalarGf())
   807  	}
   808  
   809  	// todo: fulltext indexes depend on match alias first
   810  	proj = append(proj[aliasCnt:], proj[:aliasCnt]...)
   811  
   812  	if len(proj) > 0 {
   813  		outScope.node = plan.NewProject(proj, outScope.node)
   814  	}
   815  
   816  	return outScope
   817  }
   818  
   819  // getMatchingCol returns the column in cols that matches the name, if it exists
   820  func getMatchingCol(cols []scopeColumn, name string) (scopeColumn, bool) {
   821  	for _, c := range cols {
   822  		if strings.EqualFold(c.col, name) {
   823  			return c, true
   824  		}
   825  	}
   826  	return scopeColumn{}, false
   827  }
   828  
   829  func (b *Builder) buildHaving(fromScope, projScope, outScope *scope, having *ast.Where) {
   830  	// expressions in having can be from aggOut or projScop
   831  	if having == nil {
   832  		return
   833  	}
   834  	if fromScope.groupBy == nil {
   835  		fromScope.initGroupBy()
   836  	}
   837  
   838  	havingScope := b.newScope()
   839  	if fromScope.parent != nil {
   840  		havingScope.parent = fromScope.parent
   841  	}
   842  
   843  	// add columns from fromScope referenced in the groupBy
   844  	for _, c := range fromScope.groupBy.inCols {
   845  		if !havingScope.colset.Contains(sql.ColumnId(c.id)) {
   846  			havingScope.addColumn(c)
   847  		}
   848  	}
   849  
   850  	// add columns from fromScope referenced in any aggregate expressions
   851  	for _, c := range fromScope.groupBy.aggregations() {
   852  		transform.InspectExpr(c.scalar, func(e sql.Expression) bool {
   853  			switch e := e.(type) {
   854  			case *expression.GetField:
   855  				col, found := getMatchingCol(fromScope.cols, e.Name())
   856  				if found && !havingScope.colset.Contains(sql.ColumnId(col.id)) {
   857  					havingScope.addColumn(col)
   858  				}
   859  			}
   860  			return false
   861  		})
   862  	}
   863  
   864  	// Add columns from projScope referenced in any aggregate expressions, that are not already in the havingScope
   865  	// This prevents aliases with the same name from overriding columns in the fromScope
   866  	// Additionally, the original name from plain aliases (not expressions) are added to havingScope
   867  	for _, c := range projScope.cols {
   868  		if !havingScope.colset.Contains(sql.ColumnId(c.id)) {
   869  			havingScope.addColumn(c)
   870  		}
   871  		// The unaliased column is allowed in having clauses regardless if it is just an aliased getfield and not an expression
   872  		alias, isAlias := c.scalar.(*expression.Alias)
   873  		if !isAlias {
   874  			continue
   875  		}
   876  		gf, isGetField := alias.Child.(*expression.GetField)
   877  		if !isGetField {
   878  			continue
   879  		}
   880  		col, found := getMatchingCol(fromScope.cols, gf.Name())
   881  		if found && !havingScope.colset.Contains(sql.ColumnId(col.id)) {
   882  			havingScope.addColumn(col)
   883  		}
   884  	}
   885  
   886  	havingScope.groupBy = fromScope.groupBy
   887  	h := b.buildScalar(havingScope, having.Expr)
   888  	outScope.node = plan.NewHaving(h, outScope.node)
   889  	return
   890  }