github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/opt/optbuilder/window.go (about) 1 // Copyright 2019 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package optbuilder 12 13 import ( 14 "context" 15 "fmt" 16 17 "github.com/cockroachdb/cockroach/pkg/sql/opt" 18 "github.com/cockroachdb/cockroach/pkg/sql/opt/memo" 19 "github.com/cockroachdb/cockroach/pkg/sql/opt/props/physical" 20 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" 21 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" 22 "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" 23 "github.com/cockroachdb/cockroach/pkg/sql/types" 24 "github.com/cockroachdb/errors" 25 ) 26 27 // windowInfo stores information about a window function call. 28 type windowInfo struct { 29 *tree.FuncExpr 30 31 def memo.FunctionPrivate 32 33 // col is the output column of the aggregation. 34 col *scopeColumn 35 } 36 37 // Walk is part of the tree.Expr interface. 38 func (w *windowInfo) Walk(v tree.Visitor) tree.Expr { 39 return w 40 } 41 42 // TypeCheck is part of the tree.Expr interface. 43 func (w *windowInfo) TypeCheck( 44 ctx context.Context, semaCtx *tree.SemaContext, desired *types.T, 45 ) (tree.TypedExpr, error) { 46 if _, err := w.FuncExpr.TypeCheck(ctx, semaCtx, desired); err != nil { 47 return nil, err 48 } 49 return w, nil 50 } 51 52 // Eval is part of the tree.TypedExpr interface. 53 func (w *windowInfo) Eval(_ *tree.EvalContext) (tree.Datum, error) { 54 panic(errors.AssertionFailedf("windowInfo must be replaced before evaluation")) 55 } 56 57 var _ tree.Expr = &windowInfo{} 58 var _ tree.TypedExpr = &windowInfo{} 59 60 var unboundedStartBound = &tree.WindowFrameBound{BoundType: tree.UnboundedPreceding} 61 var unboundedEndBound = &tree.WindowFrameBound{BoundType: tree.UnboundedFollowing} 62 var defaultStartBound = &tree.WindowFrameBound{BoundType: tree.UnboundedPreceding} 63 var defaultEndBound = &tree.WindowFrameBound{BoundType: tree.CurrentRow} 64 65 // buildWindow adds any window functions on top of the expression. 66 func (b *Builder) buildWindow(outScope *scope, inScope *scope) { 67 if len(inScope.windows) == 0 { 68 return 69 } 70 71 argLists := make([][]opt.ScalarExpr, len(inScope.windows)) 72 partitions := make([]opt.ColSet, len(inScope.windows)) 73 orderings := make([]physical.OrderingChoice, len(inScope.windows)) 74 filterCols := make([]opt.ColumnID, len(inScope.windows)) 75 defs := make([]*tree.WindowDef, len(inScope.windows)) 76 windowFrames := make([]tree.WindowFrame, len(inScope.windows)) 77 argScope := outScope.push() 78 argScope.appendColumnsFromScope(outScope) 79 80 // The arguments to a given window function need to be columns in the input 81 // relation. Build a projection that produces those values to go underneath 82 // the window functions. 83 // TODO(justin): this is unfortunate in common cases where the arguments are 84 // constant, since we'll be projecting an extra column in every row. It 85 // would be good if the windower supported being specified with constant 86 // values. 87 for i := range inScope.windows { 88 w := inScope.windows[i].expr.(*windowInfo) 89 90 def := w.WindowDef 91 defs[i] = def 92 93 argExprs := b.getTypedWindowArgs(w) 94 95 // Build the appropriate arguments. 96 argLists[i] = b.buildWindowArgs(argExprs, i, w.def.Name, inScope, argScope) 97 98 // Build appropriate partitions. 99 partitions[i] = b.buildWindowPartition(def.Partitions, i, w.def.Name, inScope, argScope) 100 101 // Build appropriate orderings. 102 ord := b.buildWindowOrdering(def.OrderBy, i, w.def.Name, inScope, argScope) 103 orderings[i].FromOrdering(ord) 104 105 if def.Frame != nil { 106 windowFrames[i] = *def.Frame 107 } 108 109 if w.Filter != nil { 110 col := b.buildFilterCol(w.Filter, i, w.def.Name, inScope, argScope) 111 filterCols[i] = col.id 112 } 113 114 // Fill this in with the default so that we don't need nil checks 115 // elsewhere. 116 if windowFrames[i].Bounds.StartBound == nil { 117 windowFrames[i].Bounds.StartBound = defaultStartBound 118 } 119 if windowFrames[i].Bounds.EndBound == nil { 120 // Some sources appear to say that the presence of an ORDER BY changes 121 // this between CURRENT ROW and UNBOUNDED FOLLOWING, but in reality, what 122 // CURRENT ROW means is the *last row which is a peer of this row* (a 123 // peer being a row which agrees on the ordering columns), so if there is 124 // no ORDER BY, every row is a peer with every other row in its 125 // partition, which means the CURRENT ROW and UNBOUNDED FOLLOWING are 126 // equivalent. 127 windowFrames[i].Bounds.EndBound = defaultEndBound 128 } 129 } 130 131 b.constructProjectForScope(outScope, argScope) 132 outScope.expr = argScope.expr 133 134 var referencedCols opt.ColSet 135 // frames accumulates the set of distinct window frames we're computing over 136 // so that we can group functions over the same partition and ordering. 137 frames := make([]memo.WindowExpr, 0, len(inScope.windows)) 138 for i := range inScope.windows { 139 w := inScope.windows[i].expr.(*windowInfo) 140 141 frameIdx := b.findMatchingFrameIndex(&frames, partitions[i], orderings[i]) 142 143 fn := b.constructWindowFn(w.def.Name, argLists[i]) 144 145 if windowFrames[i].Bounds.StartBound.OffsetExpr != nil { 146 fn = b.factory.ConstructWindowFromOffset( 147 fn, 148 b.buildScalar( 149 w.WindowDef.Frame.Bounds.StartBound.OffsetExpr.(tree.TypedExpr), 150 inScope, 151 nil, 152 nil, 153 &referencedCols, 154 ), 155 ) 156 } 157 158 if windowFrames[i].Bounds.EndBound.OffsetExpr != nil { 159 fn = b.factory.ConstructWindowToOffset( 160 fn, 161 b.buildScalar( 162 w.WindowDef.Frame.Bounds.EndBound.OffsetExpr.(tree.TypedExpr), 163 inScope, 164 nil, 165 nil, 166 &referencedCols, 167 ), 168 ) 169 } 170 171 if !referencedCols.Empty() { 172 panic( 173 pgerror.Newf( 174 pgcode.InvalidColumnReference, 175 "argument of %s must not contain variables", 176 tree.WindowModeName(windowFrames[i].Mode), 177 ), 178 ) 179 } 180 181 if filterCols[i] != 0 { 182 fn = b.factory.ConstructAggFilter( 183 fn, 184 b.factory.ConstructVariable(filterCols[i]), 185 ) 186 } 187 188 frames[frameIdx].Windows = append(frames[frameIdx].Windows, 189 b.factory.ConstructWindowsItem( 190 fn, 191 &memo.WindowsItemPrivate{ 192 Frame: memo.WindowFrame{ 193 Mode: windowFrames[i].Mode, 194 StartBoundType: windowFrames[i].Bounds.StartBound.BoundType, 195 EndBoundType: windowFrames[i].Bounds.EndBound.BoundType, 196 FrameExclusion: windowFrames[i].Exclusion, 197 }, 198 Col: w.col.id, 199 }, 200 ), 201 ) 202 } 203 204 for _, f := range frames { 205 outScope.expr = b.factory.ConstructWindow(outScope.expr, f.Windows, &f.WindowPrivate) 206 } 207 } 208 209 // buildAggregationAsWindow builds the aggregation operators as window functions. 210 // Returns the output scope for the aggregation operation. 211 // Consider the following query that uses an ordered aggregation: 212 // 213 // SELECT array_agg(col1 ORDER BY col1) FROM tab 214 // 215 // To support this ordering, we build the aggregate as a window function like below: 216 // 217 // scalar-group-by 218 // ├── columns: array_agg:2(int[]) 219 // ├── window partition=() ordering=+1 220 // │ ├── columns: col1:1(int!null) array_agg:2(int[]) 221 // │ ├── scan tab 222 // │ │ └── columns: col1:1(int!null) 223 // │ └── windows 224 // │ └── windows-item: range from unbounded to unbounded [type=int[]] 225 // │ └── array-agg [type=int[]] 226 // │ └── variable: col1 [type=int] 227 // └── aggregations 228 // └── const-agg [type=int[]] 229 // └── variable: array_agg [type=int[]] 230 func (b *Builder) buildAggregationAsWindow( 231 groupingColSet opt.ColSet, having opt.ScalarExpr, fromScope *scope, 232 ) *scope { 233 g := fromScope.groupby 234 235 // Create the window frames based on the orderings and groupings specified. 236 argLists := make([][]opt.ScalarExpr, len(g.aggs)) 237 partitions := make([]opt.ColSet, len(g.aggs)) 238 orderings := make([]physical.OrderingChoice, len(g.aggs)) 239 filterCols := make([]opt.ColumnID, len(g.aggs)) 240 241 // Construct the pre-projection, which renders the grouping columns and the 242 // aggregate arguments, as well as any additional order by columns. 243 g.aggInScope.appendColumnsFromScope(fromScope) 244 b.constructProjectForScope(fromScope, g.aggInScope) 245 246 // Build the arguments, partitions and orderings for each aggregate. 247 for i, agg := range g.aggs { 248 argExprs := getTypedExprs(agg.Exprs) 249 250 // Build the appropriate arguments. 251 argLists[i] = b.buildWindowArgs(argExprs, i, agg.def.Name, fromScope, g.aggInScope) 252 253 // Build appropriate partitions. 254 partitions[i] = groupingColSet.Copy() 255 256 // Build appropriate orderings. 257 if !agg.isCommutative() { 258 ord := b.buildWindowOrdering(agg.OrderBy, i, agg.def.Name, fromScope, g.aggInScope) 259 orderings[i].FromOrdering(ord) 260 } 261 262 if agg.Filter != nil { 263 col := b.buildFilterCol(agg.Filter, i, agg.def.Name, fromScope, g.aggInScope) 264 filterCols[i] = col.id 265 } 266 } 267 268 // Initialize the aggregate expression. 269 aggregateExpr := g.aggInScope.expr 270 271 // frames accumulates the set of distinct window frames we're computing over 272 // so that we can group functions over the same partition and ordering. 273 frames := make([]memo.WindowExpr, 0, len(g.aggs)) 274 for i, agg := range g.aggs { 275 fn := b.constructAggregate(agg.def.Name, argLists[i]) 276 if filterCols[i] != 0 { 277 fn = b.factory.ConstructAggFilter( 278 fn, 279 b.factory.ConstructVariable(filterCols[i]), 280 ) 281 } 282 283 frameIdx := b.findMatchingFrameIndex(&frames, partitions[i], orderings[i]) 284 285 frames[frameIdx].Windows = append(frames[frameIdx].Windows, 286 b.factory.ConstructWindowsItem( 287 fn, 288 &memo.WindowsItemPrivate{ 289 Frame: windowAggregateFrame(), 290 Col: agg.col.id, 291 }, 292 ), 293 ) 294 } 295 296 for _, f := range frames { 297 aggregateExpr = b.factory.ConstructWindow(aggregateExpr, f.Windows, &f.WindowPrivate) 298 } 299 300 // Construct a grouping so the values per group are squashed down. Each of the 301 // aggregations built as window functions emit an aggregated value for each row 302 // instead of each group. To rectify this, we must 'squash' the values down by 303 // wrapping it with a GroupBy or ScalarGroupBy. 304 g.aggOutScope.expr = b.constructWindowGroup(aggregateExpr, groupingColSet, g.aggs, g.aggOutScope) 305 306 // Wrap with having filter if it exists. 307 if having != nil { 308 input := g.aggOutScope.expr.(memo.RelExpr) 309 filters := memo.FiltersExpr{b.factory.ConstructFiltersItem(having)} 310 g.aggOutScope.expr = b.factory.ConstructSelect(input, filters) 311 } 312 return g.aggOutScope 313 } 314 315 // getTypedWindowArgs returns the arguments to the window function as 316 // a []tree.TypedExpr. In the case of arguments with default values, it 317 // fills in the values if they are missing. 318 // TODO(justin): this is a bit of a hack to get around the fact that we don't 319 // have a good way to represent optional values in the opt tree, figure out 320 // a better way to do this. In particular this is bad because it results in us 321 // projecting the default argument to some window functions when we could just 322 // not do that projection. 323 func (b *Builder) getTypedWindowArgs(w *windowInfo) []tree.TypedExpr { 324 argExprs := getTypedExprs(w.Exprs) 325 326 switch w.def.Name { 327 // The second argument of {lead,lag} is 1 by default, and the third argument 328 // is NULL by default. 329 case "lead", "lag": 330 if len(argExprs) < 2 { 331 argExprs = append(argExprs, tree.NewDInt(1)) 332 } 333 if len(argExprs) < 3 { 334 null := tree.ReType(tree.DNull, argExprs[0].ResolvedType()) 335 argExprs = append(argExprs, null) 336 } 337 } 338 339 return argExprs 340 } 341 342 // buildWindowArgs builds the argExprs into a slice of memo.ScalarListExpr. 343 func (b *Builder) buildWindowArgs( 344 argExprs []tree.TypedExpr, windowIndex int, funcName string, inScope, outScope *scope, 345 ) memo.ScalarListExpr { 346 argList := make(memo.ScalarListExpr, len(argExprs)) 347 for j, a := range argExprs { 348 col := outScope.findExistingCol(a, false /* allowSideEffects */) 349 if col == nil { 350 col = b.synthesizeColumn( 351 outScope, 352 fmt.Sprintf("%s_%d_arg%d", funcName, windowIndex+1, j+1), 353 a.ResolvedType(), 354 a, 355 b.buildScalar(a, inScope, nil, nil, nil), 356 ) 357 } 358 argList[j] = b.factory.ConstructVariable(col.id) 359 } 360 return argList 361 } 362 363 // buildWindowPartition builds the appropriate partitions for window functions. 364 func (b *Builder) buildWindowPartition( 365 partitions []tree.Expr, windowIndex int, funcName string, inScope, outScope *scope, 366 ) opt.ColSet { 367 partition := make([]tree.TypedExpr, len(partitions)) 368 for i := range partition { 369 partition[i] = partitions[i].(tree.TypedExpr) 370 } 371 372 // PARTITION BY (a, b) => PARTITION BY a, b 373 var windowPartition opt.ColSet 374 cols := flattenTuples(partition) 375 for j, e := range cols { 376 col := outScope.findExistingCol(e, false /* allowSideEffects */) 377 if col == nil { 378 col = b.synthesizeColumn( 379 outScope, 380 fmt.Sprintf("%s_%d_partition_%d", funcName, windowIndex+1, j+1), 381 e.ResolvedType(), 382 e, 383 b.buildScalar(e, inScope, nil, nil, nil), 384 ) 385 } 386 windowPartition.Add(col.id) 387 } 388 return windowPartition 389 } 390 391 // buildWindowOrdering builds the appropriate orderings for window functions. 392 func (b *Builder) buildWindowOrdering( 393 orderBy tree.OrderBy, windowIndex int, funcName string, inScope, outScope *scope, 394 ) opt.Ordering { 395 ord := make(opt.Ordering, 0, len(orderBy)) 396 for j, t := range orderBy { 397 // ORDER BY (a, b) => ORDER BY a, b. 398 te := inScope.resolveType(t.Expr, types.Any) 399 cols := flattenTuples([]tree.TypedExpr{te}) 400 401 for _, e := range cols { 402 col := outScope.findExistingCol(e, false /* allowSideEffects */) 403 if col == nil { 404 col = b.synthesizeColumn( 405 outScope, 406 fmt.Sprintf("%s_%d_orderby_%d", funcName, windowIndex+1, j+1), 407 te.ResolvedType(), 408 te, 409 b.buildScalar(e, inScope, nil, nil, nil), 410 ) 411 } 412 ord = append(ord, opt.MakeOrderingColumn(col.id, t.Direction == tree.Descending)) 413 } 414 } 415 return ord 416 } 417 418 // buildFilterCol builds the filter column from the filter Expr. 419 func (b *Builder) buildFilterCol( 420 filter tree.Expr, windowIndex int, funcName string, inScope, outScope *scope, 421 ) *scopeColumn { 422 defer b.semaCtx.Properties.Restore(b.semaCtx.Properties) 423 b.semaCtx.Properties.Require("FILTER", tree.RejectSpecial) 424 425 te := inScope.resolveAndRequireType(filter, types.Bool) 426 427 col := outScope.findExistingCol(te, false /* allowSideEffects */) 428 if col == nil { 429 col = b.synthesizeColumn( 430 outScope, 431 fmt.Sprintf("%s_%d_filter", funcName, windowIndex+1), 432 te.ResolvedType(), 433 te, 434 b.buildScalar(te, inScope, nil, nil, nil), 435 ) 436 } 437 438 return col 439 } 440 441 // findMatchingFrameIndex finds a frame position to which a window of the 442 // given partition and ordering can be added to. If no such frame is found, a 443 // new one is made. 444 func (b *Builder) findMatchingFrameIndex( 445 frames *[]memo.WindowExpr, partition opt.ColSet, ordering physical.OrderingChoice, 446 ) int { 447 frameIdx := -1 448 449 // The number of window functions is probably fairly small, so do an O(n^2) 450 // loop. 451 // TODO(justin): make this faster. 452 // TODO(justin): consider coalescing frames with compatible orderings. 453 for j := range *frames { 454 if partition.Equals((*frames)[j].Partition) && 455 ordering.Equals(&(*frames)[j].Ordering) { 456 frameIdx = j 457 break 458 } 459 } 460 461 var rangeOffsetColumn opt.ColumnID 462 if len(ordering.Columns) == 1 { 463 rangeOffsetColumn = ordering.Columns[0].AnyID() 464 } 465 // If we can't reuse an existing frame, make a new one. 466 if frameIdx == -1 { 467 *frames = append(*frames, memo.WindowExpr{ 468 WindowPrivate: memo.WindowPrivate{ 469 Partition: partition, 470 Ordering: ordering, 471 RangeOffsetColumn: rangeOffsetColumn, 472 }, 473 Windows: memo.WindowsExpr{}, 474 }) 475 frameIdx = len(*frames) - 1 476 } 477 478 return frameIdx 479 } 480 481 // constructWindowGroup wraps the input window expression with an appropriate 482 // grouping so the results of each window column are squashed down. 483 // The expression may be wrapped with a projection so ensure the default NULL 484 // values of the aggregates are respected when no rows are returned. 485 func (b *Builder) constructWindowGroup( 486 input memo.RelExpr, groupingColSet opt.ColSet, aggInfos []aggregateInfo, outScope *scope, 487 ) memo.RelExpr { 488 if groupingColSet.Empty() { 489 // Construct a scalar GroupBy wrapped around the appropriate projections. 490 return b.constructScalarWindowGroup(input, groupingColSet, aggInfos, outScope) 491 } 492 493 // Construct a GroupBy using the groupingColSet. Use the ConstAgg aggregate for 494 // the window columns. 495 private := memo.GroupingPrivate{GroupingCols: groupingColSet} 496 private.Ordering.FromOrderingWithOptCols(nil, groupingColSet) 497 aggs := make(memo.AggregationsExpr, 0, len(aggInfos)) 498 for i := range aggInfos { 499 aggs = append(aggs, b.factory.ConstructAggregationsItem( 500 b.factory.ConstructConstAgg(b.factory.ConstructVariable(aggInfos[i].col.id)), 501 aggInfos[i].col.id, 502 )) 503 } 504 return b.factory.ConstructGroupBy(input, aggs, &private) 505 } 506 507 // replaceDefaultReturn constructs a case expression to apply as a projection over 508 // a ScalarGroupBy expression, that replaces the default NULL value from matchVal 509 // to replaceVal. 510 func (b *Builder) replaceDefaultReturn( 511 varExpr, matchVal, replaceVal opt.ScalarExpr, 512 ) opt.ScalarExpr { 513 return b.factory.ConstructCase( 514 memo.TrueSingleton, 515 memo.ScalarListExpr{ 516 b.factory.ConstructWhen( 517 b.factory.ConstructIs(varExpr, matchVal), 518 replaceVal, 519 ), 520 }, 521 varExpr, 522 ) 523 } 524 525 // overrideDefaultNullValue checks whether the aggregate has a predefined null 526 // value for scalar group by when no rows are returned. The default null value 527 // to be applied is also returned. 528 func (b *Builder) overrideDefaultNullValue(agg aggregateInfo) (opt.ScalarExpr, bool) { 529 switch agg.def.Name { 530 case "count", "count_rows": 531 return b.factory.ConstructConst(tree.NewDInt(0), types.Int), true 532 default: 533 return nil, false 534 } 535 } 536 537 // constructScalarWindowGroup wraps the input window expression with a scalar 538 // grouping so the results of each window column are squashed down. 539 // The expression may be wrapped with a projection so ensure the default NULL 540 // values of the aggregates are respected when no rows are returned. 541 func (b *Builder) constructScalarWindowGroup( 542 input memo.RelExpr, groupingColSet opt.ColSet, aggInfos []aggregateInfo, outScope *scope, 543 ) memo.RelExpr { 544 private := memo.GroupingPrivate{GroupingCols: groupingColSet} 545 private.Ordering.FromOrderingWithOptCols(nil, groupingColSet) 546 aggs := make(memo.AggregationsExpr, 0, len(aggInfos)) 547 548 // Create a projection here to replace the NULL values with pre-defined 549 // default values of aggregates. The projection should be of the form: 550 // 551 // CASE true WHEN aggregate_result = NULL THEN default_val ELSE aggregate_result 552 // 553 // aggregate_result above is the column created by the window function after 554 // computing an aggregate. default_val is the default value for the aggregate. 555 // Example: 556 // 557 // CASE true WHEN count = NULL THEN 0 ELSE count 558 559 // Create the projections expression. 560 projections := make(memo.ProjectionsExpr, 0, len(aggInfos)) 561 562 // Create an appropriate passthrough for the projection. 563 passthrough := input.Relational().OutputCols.Copy() 564 for i := range aggInfos { 565 varExpr := b.factory.ConstructConstAgg(b.factory.ConstructVariable(aggInfos[i].col.id)) 566 567 // If the aggregate requires a projection to potentially set a default null value 568 // a new column will be needed to be synthesized. 569 defaultNullVal, requiresProjection := b.overrideDefaultNullValue(aggInfos[i]) 570 aggregateCol := aggInfos[i].col 571 if requiresProjection { 572 aggregateCol = b.synthesizeColumn(outScope, aggregateCol.name.String(), aggregateCol.typ, aggregateCol.expr, varExpr) 573 } 574 575 aggs = append(aggs, b.factory.ConstructAggregationsItem(varExpr, aggregateCol.id)) 576 passthrough.Add(aggInfos[i].col.id) 577 578 // Add projection to replace default NULL value. 579 if requiresProjection { 580 projections = append(projections, b.factory.ConstructProjectionsItem( 581 b.replaceDefaultReturn( 582 b.factory.ConstructVariable(aggregateCol.id), 583 memo.NullSingleton, 584 defaultNullVal), 585 aggInfos[i].col.id, 586 )) 587 passthrough.Remove(aggInfos[i].col.id) 588 } 589 } 590 591 scalarAggExpr := b.factory.ConstructScalarGroupBy(input, aggs, &private) 592 if len(projections) != 0 { 593 return b.factory.ConstructProject(scalarAggExpr, projections, passthrough) 594 } 595 return scalarAggExpr 596 }