github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/aggregates.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package planbuilder 16 17 import ( 18 "fmt" 19 "sort" 20 "strings" 21 22 ast "github.com/dolthub/vitess/go/vt/sqlparser" 23 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/expression" 26 "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation" 27 "github.com/dolthub/go-mysql-server/sql/plan" 28 "github.com/dolthub/go-mysql-server/sql/transform" 29 "github.com/dolthub/go-mysql-server/sql/types" 30 ) 31 32 var _ ast.Expr = (*aggregateInfo)(nil) 33 34 type groupBy struct { 35 inCols []scopeColumn 36 outScope *scope 37 aggs map[string]scopeColumn 38 grouping map[string]bool 39 } 40 41 func (g *groupBy) addInCol(c scopeColumn) { 42 g.inCols = append(g.inCols, c) 43 } 44 45 func (g *groupBy) addOutCol(c scopeColumn) columnId { 46 return g.outScope.newColumn(c) 47 } 48 49 func (g *groupBy) hasAggs() bool { 50 return len(g.aggs) > 0 51 } 52 53 func (g *groupBy) aggregations() []scopeColumn { 54 aggregations := make([]scopeColumn, 0, len(g.aggs)) 55 for _, agg := range g.aggs { 56 aggregations = append(aggregations, agg) 57 } 58 sort.Slice(aggregations, func(i, j int) bool { 59 return aggregations[i].scalar.String() < aggregations[j].scalar.String() 60 }) 61 return aggregations 62 } 63 64 func (g *groupBy) addAggStr(c scopeColumn) { 65 if g.aggs == nil { 66 g.aggs = make(map[string]scopeColumn) 67 } 68 g.aggs[strings.ToLower(c.scalar.String())] = c 69 } 70 71 func (g *groupBy) getAggRef(name string) sql.Expression { 72 if g.aggs == nil { 73 return nil 74 } 75 ret, _ := g.aggs[name] 76 if ret.empty() { 77 return nil 78 } 79 return ret.scalarGf() 80 } 81 82 type aggregateInfo struct { 83 ast.Expr 84 } 85 86 func (b *Builder) needsAggregation(fromScope *scope, sel *ast.Select) bool { 87 return len(sel.GroupBy) > 0 || 88 (fromScope.groupBy != nil && fromScope.groupBy.hasAggs()) 89 } 90 91 func (b *Builder) buildGroupingCols(fromScope, projScope *scope, groupby ast.GroupBy, selects ast.SelectExprs) []sql.Expression { 92 // grouping col will either be: 93 // 1) alias into targets 94 // 2) a column reference 95 // 3) an index into selects 96 // 4) a simple non-aggregate expression 97 groupings := make([]sql.Expression, 0) 98 if fromScope.groupBy == nil { 99 fromScope.initGroupBy() 100 } 101 g := fromScope.groupBy 102 for _, e := range groupby { 103 var col scopeColumn 104 switch e := e.(type) { 105 case *ast.ColName: 106 var ok bool 107 // GROUP BY binds to column references before projections. 108 dbName := strings.ToLower(e.Qualifier.Qualifier.String()) 109 tblName := strings.ToLower(e.Qualifier.Name.String()) 110 colName := strings.ToLower(e.Name.String()) 111 col, ok = fromScope.resolveColumn(dbName, tblName, colName, true, false) 112 if !ok { 113 col, ok = projScope.resolveColumn(dbName, tblName, colName, true, true) 114 } 115 116 if !ok { 117 b.handleErr(sql.ErrColumnNotFound.New(e.Name.String())) 118 } 119 case *ast.SQLVal: 120 // literal -> index into targets 121 replace := b.normalizeValArg(e) 122 val, ok := replace.(*ast.SQLVal) 123 if !ok { 124 // ast.NullVal 125 continue 126 } 127 if val.Type == ast.IntVal { 128 lit := b.convertInt(string(val.Val), 10) 129 idx, _, err := types.Int64.Convert(lit.Value()) 130 if err != nil { 131 b.handleErr(err) 132 } 133 intIdx, ok := idx.(int64) 134 if !ok { 135 b.handleErr(fmt.Errorf("expected integer order by literal")) 136 } 137 if intIdx < 1 { 138 b.handleErr(fmt.Errorf("expected positive integer order by literal")) 139 } 140 col = projScope.cols[intIdx-1] 141 } 142 default: 143 expr := b.buildScalar(fromScope, e) 144 col = scopeColumn{ 145 col: expr.String(), 146 typ: nil, 147 scalar: expr, 148 nullable: expr.IsNullable(), 149 } 150 } 151 if col.scalar == nil { 152 gf := expression.NewGetFieldWithTable(int(col.id), int(col.tableId), col.typ, col.db, col.table, col.col, col.nullable) 153 id, ok := fromScope.getExpr(gf.String(), true) 154 if !ok { 155 err := sql.ErrColumnNotFound.New(gf.String()) 156 b.handleErr(err) 157 } 158 col.scalar = gf.WithIndex(int(id)) 159 } 160 g.addInCol(col) 161 groupings = append(groupings, col.scalar) 162 } 163 164 return groupings 165 } 166 167 func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []sql.Expression) *scope { 168 // GROUP_BY consists of: 169 // - input arguments projection 170 // - grouping cols projection 171 // - aggregate expressions 172 // - output projection 173 if fromScope.groupBy == nil { 174 fromScope.initGroupBy() 175 } 176 177 group := fromScope.groupBy 178 outScope := group.outScope 179 // select columns: 180 // - aggs 181 // - extra columns needed by having, order by, select 182 var selectExprs []sql.Expression 183 var selectGfs []sql.Expression 184 selectStr := make(map[string]bool) 185 for _, e := range group.aggregations() { 186 if !selectStr[strings.ToLower(e.String())] { 187 selectExprs = append(selectExprs, e.scalar) 188 selectGfs = append(selectGfs, e.scalarGf()) 189 selectStr[strings.ToLower(e.String())] = true 190 } 191 } 192 var aliases []sql.Expression 193 for _, col := range projScope.cols { 194 // eval aliases in project scope 195 switch e := col.scalar.(type) { 196 case *expression.Alias: 197 if !e.Unreferencable() { 198 aliases = append(aliases, e.WithId(sql.ColumnId(col.id)).(*expression.Alias)) 199 } 200 default: 201 } 202 203 // projection dependencies -> table cols needed above 204 transform.InspectExpr(col.scalar, func(e sql.Expression) bool { 205 switch e := e.(type) { 206 case *expression.GetField: 207 colName := strings.ToLower(e.String()) 208 if !selectStr[colName] { 209 selectExprs = append(selectExprs, e) 210 selectGfs = append(selectGfs, e) 211 selectStr[colName] = true 212 } 213 default: 214 } 215 return false 216 }) 217 } 218 for _, e := range fromScope.extraCols { 219 // accessory cols used by ORDER_BY, HAVING 220 if !selectStr[e.String()] { 221 selectExprs = append(selectExprs, e.scalarGf()) 222 selectGfs = append(selectGfs, e.scalarGf()) 223 224 selectStr[e.String()] = true 225 } 226 } 227 gb := plan.NewGroupBy(selectExprs, groupingCols, fromScope.node) 228 outScope.node = gb 229 230 if len(aliases) > 0 { 231 outScope.node = plan.NewProject(append(selectGfs, aliases...), outScope.node) 232 } 233 return outScope 234 } 235 236 func isAggregateFunc(name string) bool { 237 switch name { 238 case "avg", "bit_and", "bit_or", "bit_xor", "count", 239 "group_concat", "json_arrayagg", "json_objectagg", 240 "max", "min", "std", "stddev_pop", "stddev_samp", 241 "stddev", "sum", "var_pop", "var_samp", "variance", 242 "first", "last", "any_value": 243 return true 244 default: 245 return false 246 } 247 } 248 249 // buildAggregateFunc tags aggregate functions in the correct scope 250 // and makes the aggregate available for reference by other clauses. 251 func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExpr) sql.Expression { 252 if len(inScope.windowFuncs) > 0 { 253 err := sql.ErrNonAggregatedColumnWithoutGroupBy.New() 254 b.handleErr(err) 255 } 256 257 if inScope.groupBy == nil { 258 inScope.initGroupBy() 259 } 260 gb := inScope.groupBy 261 262 if name == "count" { 263 if _, ok := e.Exprs[0].(*ast.StarExpr); ok { 264 var agg sql.Aggregation 265 if e.Distinct { 266 agg = aggregation.NewCountDistinct(expression.NewLiteral(1, types.Int64)) 267 } else { 268 agg = aggregation.NewCount(expression.NewLiteral(1, types.Int64)) 269 } 270 aggName := strings.ToLower(agg.String()) 271 gf := gb.getAggRef(aggName) 272 if gf != nil { 273 // if we've already computed use reference here 274 return gf 275 } 276 277 col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()} 278 id := gb.outScope.newColumn(col) 279 col.id = id 280 281 agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation) 282 gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg 283 col.scalar = agg 284 285 gb.addAggStr(col) 286 return col.scalarGf() 287 } 288 } 289 290 if name == "jsonarray" { 291 // TODO we don't have any tests for this 292 if _, ok := e.Exprs[0].(*ast.StarExpr); ok { 293 var agg sql.Aggregation 294 agg = aggregation.NewJsonArray(expression.NewLiteral(expression.NewStar(), types.Int64)) 295 //if e.Distinct { 296 // agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64)) 297 //} 298 aggName := strings.ToLower(agg.String()) 299 gf := gb.getAggRef(aggName) 300 if gf != nil { 301 // if we've already computed use reference here 302 return gf 303 } 304 305 col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()} 306 id := gb.outScope.newColumn(col) 307 308 agg = agg.WithId(sql.ColumnId(id)).(*aggregation.JsonArray) 309 gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg 310 col.scalar = agg 311 312 col.id = id 313 gb.addAggStr(col) 314 return col.scalarGf() 315 } 316 } 317 318 var args []sql.Expression 319 for _, arg := range e.Exprs { 320 e := b.selectExprToExpression(inScope, arg) 321 switch e := e.(type) { 322 case *expression.GetField: 323 if e.TableId() == 0 { 324 // TODO: not sure where this came from but it's not true 325 // aliases are not valid aggregate arguments, the alias must be masking a column 326 gf := b.selectExprToExpression(inScope.parent, arg) 327 var ok bool 328 e, ok = gf.(*expression.GetField) 329 if !ok || e.TableId() == 0 { 330 b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", gf)) 331 } 332 } 333 args = append(args, e) 334 col := scopeColumn{tableId: e.TableID(), db: e.Database(), table: e.Table(), col: e.Name(), scalar: e, typ: e.Type(), nullable: e.IsNullable()} 335 gb.addInCol(col) 336 case *expression.Star: 337 err := sql.ErrStarUnsupported.New() 338 b.handleErr(err) 339 case *plan.Subquery: 340 args = append(args, e) 341 col := scopeColumn{col: e.QueryString, scalar: e, typ: e.Type()} 342 gb.addInCol(col) 343 default: 344 args = append(args, e) 345 col := scopeColumn{col: e.String(), scalar: e, typ: e.Type()} 346 gb.addInCol(col) 347 } 348 } 349 350 var agg sql.Aggregation 351 if e.Distinct && name == "count" { 352 agg = aggregation.NewCountDistinct(args...) 353 } else { 354 355 // NOTE: Not all aggregate functions support DISTINCT. Fortunately, the vitess parser will throw 356 // errors for when DISTINCT is used on aggregate functions that don't support DISTINCT. 357 if e.Distinct { 358 if len(e.Exprs) != 1 { 359 err := sql.ErrUnsupportedSyntax.New("more than one expression with distinct") 360 b.handleErr(err) 361 } 362 363 args[0] = expression.NewDistinctExpression(args[0]) 364 } 365 366 f, err := b.cat.Function(b.ctx, name) 367 if err != nil { 368 b.handleErr(err) 369 } 370 371 newInst, err := f.NewInstance(args) 372 if err != nil { 373 b.handleErr(err) 374 } 375 var ok bool 376 agg, ok = newInst.(sql.Aggregation) 377 if !ok { 378 err := fmt.Errorf("expected function to be aggregation: %s", f.FunctionName()) 379 b.handleErr(err) 380 } 381 } 382 383 aggType := agg.Type() 384 if name == "avg" || name == "sum" { 385 aggType = types.Float64 386 } 387 388 aggName := strings.ToLower(plan.AliasSubqueryString(agg)) 389 if id, ok := gb.outScope.getExpr(aggName, true); ok { 390 // if we've already computed use reference here 391 gf := expression.NewGetFieldWithTable(int(id), 0, aggType, "", "", aggName, agg.IsNullable()) 392 return gf 393 } 394 395 col := scopeColumn{col: aggName, scalar: agg, typ: aggType, nullable: agg.IsNullable()} 396 id := gb.outScope.newColumn(col) 397 398 agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation) 399 gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg 400 col.scalar = agg 401 402 col.id = id 403 gb.addAggStr(col) 404 return col.scalarGf() 405 } 406 407 func (b *Builder) buildGroupConcat(inScope *scope, e *ast.GroupConcatExpr) sql.Expression { 408 if inScope.groupBy == nil { 409 inScope.initGroupBy() 410 } 411 gb := inScope.groupBy 412 413 args := make([]sql.Expression, len(e.Exprs)) 414 for i, a := range e.Exprs { 415 args[i] = b.selectExprToExpression(inScope, a) 416 } 417 418 separatorS := "," 419 if !e.Separator.DefaultSeparator { 420 separatorS = e.Separator.SeparatorString 421 } 422 423 orderByScope := b.analyzeOrderBy(inScope, inScope, e.OrderBy) 424 var sortFields sql.SortFields 425 for _, c := range orderByScope.cols { 426 so := sql.Ascending 427 if c.descending { 428 so = sql.Descending 429 } 430 scalar := c.scalar 431 if scalar == nil { 432 scalar = c.scalarGf() 433 } 434 sf := sql.SortField{ 435 Column: scalar, 436 Order: so, 437 } 438 sortFields = append(sortFields, sf) 439 } 440 441 //TODO: this should be acquired at runtime, not at parse time, so fix this 442 gcml, err := b.ctx.GetSessionVariable(b.ctx, "group_concat_max_len") 443 if err != nil { 444 b.handleErr(err) 445 } 446 groupConcatMaxLen := gcml.(uint64) 447 448 // todo store ref to aggregate 449 agg := aggregation.NewGroupConcat(e.Distinct, sortFields, separatorS, args, int(groupConcatMaxLen)) 450 aggName := strings.ToLower(plan.AliasSubqueryString(agg)) 451 col := scopeColumn{col: aggName, scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()} 452 453 id := gb.outScope.newColumn(col) 454 455 agg = agg.WithId(sql.ColumnId(id)).(*aggregation.GroupConcat) 456 gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg 457 col.scalar = agg 458 459 gb.addAggStr(col) 460 col.id = id 461 return col.scalarGf() 462 } 463 464 func isWindowFunc(name string) bool { 465 switch name { 466 case "first", "last", "count", "sum", "any_value", 467 "avg", "max", "min", "count_distinct", "json_arrayagg", 468 "row_number", "percent_rank", "lead", "lag", 469 "first_value", "last_value", 470 "rank", "dense_rank": 471 return true 472 default: 473 return false 474 } 475 } 476 477 func (b *Builder) buildWindowFunc(inScope *scope, name string, e *ast.FuncExpr, over *ast.WindowDef) sql.Expression { 478 if inScope.groupBy != nil { 479 err := sql.ErrNonAggregatedColumnWithoutGroupBy.New() 480 b.handleErr(err) 481 } 482 483 // internal expressions can be complex, but window can't be more than alias 484 var args []sql.Expression 485 for _, arg := range e.Exprs { 486 e := b.selectExprToExpression(inScope, arg) 487 args = append(args, e) 488 } 489 490 var win sql.WindowAdaptableExpression 491 if name == "count" { 492 if _, ok := e.Exprs[0].(*ast.StarExpr); ok { 493 win = aggregation.NewCount(expression.NewLiteral(1, types.Int64)) 494 } 495 } 496 if win == nil { 497 f, err := b.cat.Function(b.ctx, name) 498 if err != nil { 499 b.handleErr(err) 500 } 501 502 newInst, err := f.NewInstance(args) 503 var ok bool 504 win, ok = newInst.(sql.WindowAdaptableExpression) 505 if !ok { 506 err := fmt.Errorf("function is not a window adaptable exprssion: %s", f.FunctionName()) 507 b.handleErr(err) 508 } 509 if err != nil { 510 b.handleErr(err) 511 } 512 } 513 514 def := b.buildWindowDef(inScope, over) 515 switch w := win.(type) { 516 case sql.WindowAdaptableExpression: 517 win = w.WithWindow(def) 518 } 519 520 col := scopeColumn{col: strings.ToLower(win.String()), scalar: win, typ: win.Type(), nullable: win.IsNullable()} 521 id := inScope.newColumn(col) 522 col.id = id 523 win = win.WithId(sql.ColumnId(id)).(sql.WindowAdaptableExpression) 524 inScope.cols[len(inScope.cols)-1].scalar = win 525 col.scalar = win 526 inScope.windowFuncs = append(inScope.windowFuncs, col) 527 return col.scalarGf() 528 } 529 530 func (b *Builder) buildWindow(fromScope, projScope *scope) *scope { 531 if len(fromScope.windowFuncs) == 0 { 532 return fromScope 533 } 534 // passthrough dependency cols plus window funcs 535 var selectExprs []sql.Expression 536 var selectGfs []sql.Expression 537 selectStr := make(map[string]bool) 538 for _, col := range fromScope.windowFuncs { 539 e := col.scalar 540 if !selectStr[strings.ToLower(e.String())] { 541 switch e.(type) { 542 case sql.WindowAdaptableExpression: 543 selectStr[strings.ToLower(e.String())] = true 544 selectExprs = append(selectExprs, e) 545 selectGfs = append(selectGfs, col.scalarGf()) 546 default: 547 err := fmt.Errorf("expected window function to be sql.WindowAggregation") 548 b.handleErr(err) 549 } 550 } 551 } 552 var aliases []sql.Expression 553 for _, col := range projScope.cols { 554 // eval aliases in project scope 555 switch e := col.scalar.(type) { 556 case *expression.Alias: 557 if !e.Unreferencable() { 558 aliases = append(aliases, e.WithId(sql.ColumnId(col.id)).(*expression.Alias)) 559 } 560 default: 561 } 562 563 // projection dependencies -> table cols needed above 564 transform.InspectExpr(col.scalar, func(e sql.Expression) bool { 565 switch e := e.(type) { 566 case *expression.GetField: 567 colName := strings.ToLower(e.String()) 568 if !selectStr[colName] { 569 selectExprs = append(selectExprs, e) 570 selectGfs = append(selectGfs, e) 571 selectStr[colName] = true 572 } 573 default: 574 } 575 return false 576 }) 577 } 578 for _, e := range fromScope.extraCols { 579 // accessory cols used by ORDER_BY, HAVING 580 if !selectStr[e.String()] { 581 selectExprs = append(selectExprs, e.scalarGf()) 582 selectGfs = append(selectGfs, e.scalarGf()) 583 selectStr[e.String()] = true 584 } 585 } 586 587 outScope := fromScope 588 window := plan.NewWindow(selectExprs, fromScope.node) 589 fromScope.node = window 590 591 if len(aliases) > 0 { 592 outScope.node = plan.NewProject(append(selectGfs, aliases...), outScope.node) 593 } 594 595 return outScope 596 } 597 598 func (b *Builder) buildNamedWindows(fromScope *scope, window ast.Window) { 599 // topo sort first 600 adj := make(map[string]*ast.WindowDef) 601 for _, w := range window { 602 adj[w.Name.Lowered()] = w 603 } 604 605 var topo []*ast.WindowDef 606 var seen map[string]bool 607 var dfs func(string) 608 dfs = func(name string) { 609 if ok, _ := seen[name]; ok { 610 b.handleErr(sql.ErrCircularWindowInheritance.New()) 611 } 612 seen[name] = true 613 cur := adj[name] 614 if ref := cur.NameRef.Lowered(); ref != "" { 615 dfs(ref) 616 } 617 topo = append(topo, cur) 618 } 619 for _, w := range adj { 620 seen = make(map[string]bool) 621 dfs(w.Name.Lowered()) 622 } 623 624 fromScope.windowDefs = make(map[string]*sql.WindowDefinition) 625 for _, w := range topo { 626 fromScope.windowDefs[w.Name.Lowered()] = b.buildWindowDef(fromScope, w) 627 } 628 return 629 } 630 631 func (b *Builder) buildWindowDef(fromScope *scope, def *ast.WindowDef) *sql.WindowDefinition { 632 if def == nil { 633 return nil 634 } 635 636 var sortFields sql.SortFields 637 for _, c := range def.OrderBy { 638 // resolve col in fromScope 639 e := b.buildScalar(fromScope, c.Expr) 640 so := sql.Ascending 641 if c.Direction == ast.DescScr { 642 so = sql.Descending 643 } 644 sf := sql.SortField{ 645 Column: e, 646 Order: so, 647 } 648 sortFields = append(sortFields, sf) 649 } 650 651 partitions := make([]sql.Expression, len(def.PartitionBy)) 652 for i, expr := range def.PartitionBy { 653 partitions[i] = b.buildScalar(fromScope, expr) 654 } 655 656 frame := b.NewFrame(fromScope, def.Frame) 657 658 // According to MySQL documentation at https://dev.mysql.com/doc/refman/8.0/en/window-functions-usage.html 659 // "If OVER() is empty, the window consists of all query rows and the window function computes a result using all rows." 660 if def.OrderBy == nil && frame == nil { 661 frame = plan.NewRowsUnboundedPrecedingToUnboundedFollowingFrame() 662 } 663 664 windowDef := sql.NewWindowDefinition(partitions, sortFields, frame, def.NameRef.Lowered(), def.Name.Lowered()) 665 if ref, ok := fromScope.windowDefs[def.NameRef.Lowered()]; ok { 666 // this is only safe if windows are built in topo order 667 windowDef = b.mergeWindowDefs(windowDef, ref) 668 // collapse dependencies if any reference this window 669 fromScope.windowDefs[windowDef.Name] = windowDef 670 } 671 return windowDef 672 } 673 674 // mergeWindowDefs combines the attributes of two window definitions or returns 675 // an error if the two are incompatible. [def] should have a reference to 676 // [ref] through [def.Ref], and the return value drops the reference to indicate 677 // the two were properly combined. 678 func (b *Builder) mergeWindowDefs(def, ref *sql.WindowDefinition) *sql.WindowDefinition { 679 if ref.Ref != "" { 680 panic("unreachable; cannot merge unresolved window definition") 681 } 682 683 var orderBy sql.SortFields 684 switch { 685 case len(def.OrderBy) > 0 && len(ref.OrderBy) > 0: 686 err := sql.ErrInvalidWindowInheritance.New("", "", "both contain order by clause") 687 b.handleErr(err) 688 case len(def.OrderBy) > 0: 689 orderBy = def.OrderBy 690 case len(ref.OrderBy) > 0: 691 orderBy = ref.OrderBy 692 default: 693 } 694 695 var partitionBy []sql.Expression 696 switch { 697 case len(def.PartitionBy) > 0 && len(ref.PartitionBy) > 0: 698 err := sql.ErrInvalidWindowInheritance.New("", "", "both contain partition by clause") 699 b.handleErr(err) 700 case len(def.PartitionBy) > 0: 701 partitionBy = def.PartitionBy 702 case len(ref.PartitionBy) > 0: 703 partitionBy = ref.PartitionBy 704 default: 705 partitionBy = []sql.Expression{} 706 } 707 708 var frame sql.WindowFrame 709 switch { 710 case def.Frame != nil && ref.Frame != nil: 711 _, isDefDefaultFrame := def.Frame.(*plan.RowsUnboundedPrecedingToUnboundedFollowingFrame) 712 _, isRefDefaultFrame := ref.Frame.(*plan.RowsUnboundedPrecedingToUnboundedFollowingFrame) 713 714 // if both frames are set and one is RowsUnboundedPrecedingToUnboundedFollowingFrame (default), 715 // we should use the other frame 716 if isDefDefaultFrame { 717 frame = ref.Frame 718 } else if isRefDefaultFrame { 719 frame = def.Frame 720 } else { 721 // if both frames have identical string representations, use either one 722 df := def.Frame.String() 723 rf := ref.Frame.String() 724 if df != rf { 725 err := sql.ErrInvalidWindowInheritance.New("", "", "both contain different frame clauses") 726 b.handleErr(err) 727 } 728 frame = def.Frame 729 } 730 case def.Frame != nil: 731 frame = def.Frame 732 case ref.Frame != nil: 733 frame = ref.Frame 734 default: 735 } 736 737 return sql.NewWindowDefinition(partitionBy, orderBy, frame, "", def.Name) 738 } 739 740 func (b *Builder) analyzeHaving(fromScope, projScope *scope, having *ast.Where) { 741 // build having filter expr 742 // aggregates added to fromScope.groupBy 743 // can see projScope outputs 744 if having == nil { 745 return 746 } 747 748 ast.Walk(func(node ast.SQLNode) (bool, error) { 749 switch n := node.(type) { 750 case *ast.Subquery: 751 return false, nil 752 case *ast.FuncExpr: 753 name := n.Name.Lowered() 754 if isAggregateFunc(name) { 755 // record aggregate 756 // TODO: this should get projScope as well 757 _ = b.buildAggregateFunc(fromScope, name, n) 758 } else if isWindowFunc(name) { 759 _ = b.buildWindowFunc(fromScope, name, n, (*ast.WindowDef)(n.Over)) 760 } 761 case *ast.ColName: 762 // add to extra cols 763 dbName := strings.ToLower(n.Qualifier.Qualifier.String()) 764 tblName := strings.ToLower(n.Qualifier.Name.String()) 765 colName := strings.ToLower(n.Name.String()) 766 c, ok := fromScope.resolveColumn(dbName, tblName, colName, true, false) 767 if ok { 768 c.scalar = expression.NewGetFieldWithTable(int(c.id), 0, c.typ, c.db, c.table, c.col, c.nullable) 769 fromScope.addExtraColumn(c) 770 break 771 } 772 c, ok = projScope.resolveColumn(dbName, tblName, colName, false, true) 773 if ok { 774 // references projection alias 775 break 776 } 777 err := sql.ErrColumnNotFound.New(n.Name) 778 b.handleErr(err) 779 } 780 return true, nil 781 }, having.Expr) 782 } 783 784 func (b *Builder) buildInnerProj(fromScope, projScope *scope) *scope { 785 outScope := fromScope 786 var proj []sql.Expression 787 788 // eval aliases in project scope 789 for _, col := range projScope.cols { 790 switch e := col.scalar.(type) { 791 case *expression.Alias: 792 if !e.Unreferencable() { 793 proj = append(proj, e.WithId(sql.ColumnId(col.id)).(*expression.Alias)) 794 } 795 } 796 } 797 798 aliasCnt := len(proj) 799 800 if len(proj) == 0 && !(len(fromScope.cols) == 1 && fromScope.cols[0].id == 0) { 801 // remove redundant projection unless it is the single dual table column 802 return outScope 803 } 804 805 for _, c := range fromScope.cols { 806 proj = append(proj, c.scalarGf()) 807 } 808 809 // todo: fulltext indexes depend on match alias first 810 proj = append(proj[aliasCnt:], proj[:aliasCnt]...) 811 812 if len(proj) > 0 { 813 outScope.node = plan.NewProject(proj, outScope.node) 814 } 815 816 return outScope 817 } 818 819 // getMatchingCol returns the column in cols that matches the name, if it exists 820 func getMatchingCol(cols []scopeColumn, name string) (scopeColumn, bool) { 821 for _, c := range cols { 822 if strings.EqualFold(c.col, name) { 823 return c, true 824 } 825 } 826 return scopeColumn{}, false 827 } 828 829 func (b *Builder) buildHaving(fromScope, projScope, outScope *scope, having *ast.Where) { 830 // expressions in having can be from aggOut or projScop 831 if having == nil { 832 return 833 } 834 if fromScope.groupBy == nil { 835 fromScope.initGroupBy() 836 } 837 838 havingScope := b.newScope() 839 if fromScope.parent != nil { 840 havingScope.parent = fromScope.parent 841 } 842 843 // add columns from fromScope referenced in the groupBy 844 for _, c := range fromScope.groupBy.inCols { 845 if !havingScope.colset.Contains(sql.ColumnId(c.id)) { 846 havingScope.addColumn(c) 847 } 848 } 849 850 // add columns from fromScope referenced in any aggregate expressions 851 for _, c := range fromScope.groupBy.aggregations() { 852 transform.InspectExpr(c.scalar, func(e sql.Expression) bool { 853 switch e := e.(type) { 854 case *expression.GetField: 855 col, found := getMatchingCol(fromScope.cols, e.Name()) 856 if found && !havingScope.colset.Contains(sql.ColumnId(col.id)) { 857 havingScope.addColumn(col) 858 } 859 } 860 return false 861 }) 862 } 863 864 // Add columns from projScope referenced in any aggregate expressions, that are not already in the havingScope 865 // This prevents aliases with the same name from overriding columns in the fromScope 866 // Additionally, the original name from plain aliases (not expressions) are added to havingScope 867 for _, c := range projScope.cols { 868 if !havingScope.colset.Contains(sql.ColumnId(c.id)) { 869 havingScope.addColumn(c) 870 } 871 // The unaliased column is allowed in having clauses regardless if it is just an aliased getfield and not an expression 872 alias, isAlias := c.scalar.(*expression.Alias) 873 if !isAlias { 874 continue 875 } 876 gf, isGetField := alias.Child.(*expression.GetField) 877 if !isGetField { 878 continue 879 } 880 col, found := getMatchingCol(fromScope.cols, gf.Name()) 881 if found && !havingScope.colset.Contains(sql.ColumnId(col.id)) { 882 havingScope.addColumn(col) 883 } 884 } 885 886 havingScope.groupBy = fromScope.groupBy 887 h := b.buildScalar(havingScope, having.Expr) 888 outScope.node = plan.NewHaving(h, outScope.node) 889 return 890 }