github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/window_framer.go (about) 1 // Copyright 2022 DoltHub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aggregation 16 17 import ( 18 "errors" 19 "io" 20 21 ast "github.com/dolthub/vitess/go/vt/sqlparser" 22 sqlerr "gopkg.in/src-d/go-errors.v1" 23 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/expression" 26 ) 27 28 //go:generate go run ../../../../optgen/cmd/optgen/main.go -out window_framer.og.go -pkg aggregation framer window_framer.go 29 30 var ErrPartitionNotSet = errors.New("attempted to general a window frame interval before framer partition was set") 31 var ErrRangeIntervalTypeMismatch = errors.New("range bound type must match the order by expression type") 32 var ErrRangeInvalidOrderBy = sqlerr.NewKind("a range's order by must be one expression; found: %d") 33 34 var _ sql.WindowFramer = (*PartitionFramer)(nil) 35 var _ sql.WindowFramer = (*GroupByFramer)(nil) 36 var _ sql.WindowFramer = (*rowFramerBase)(nil) 37 var _ sql.WindowFramer = (*rangeFramerBase)(nil) 38 var _ sql.WindowFramer = (*PeerGroupFramer)(nil) 39 40 // NewUnboundedPrecedingToCurrentRowFramer generates sql.WindowInterval 41 // from the first row in a partition to the current row. 
42 // 43 // Ex: partition = [0, 1, 2, 3, 4, 5] 44 // => 45 // frames: {0,0}, {0,1}, {0,2}, {0,3}, {0,4}, {0,5} 46 // rows: [0], [0,1], [0,1,2], [1,2,3,4], [1,2,3,4], [1,2,3,4,5] 47 func NewUnboundedPrecedingToCurrentRowFramer() *RowsUnboundedPrecedingToCurrentRowFramer { 48 return &RowsUnboundedPrecedingToCurrentRowFramer{ 49 rowFramerBase{ 50 unboundedPreceding: true, 51 endCurrentRow: false, 52 }, 53 } 54 } 55 56 type PartitionFramer struct { 57 idx int 58 partitionStart, partitionEnd int 59 60 followOffset, precOffset int 61 frameStart, frameEnd int 62 partitionSet bool 63 } 64 65 func NewPartitionFramer() *PartitionFramer { 66 return &PartitionFramer{ 67 idx: -1, 68 frameEnd: -1, 69 frameStart: -1, 70 partitionStart: -1, 71 partitionEnd: -1, 72 } 73 } 74 75 func (f *PartitionFramer) NewFramer(interval sql.WindowInterval) (sql.WindowFramer, error) { 76 return &PartitionFramer{ 77 idx: interval.Start, 78 frameEnd: interval.End, 79 frameStart: interval.Start, 80 partitionStart: interval.Start, 81 partitionEnd: interval.End, 82 partitionSet: true, 83 }, nil 84 } 85 86 func (f *PartitionFramer) Next(ctx *sql.Context, buffer sql.WindowBuffer) (sql.WindowInterval, error) { 87 if !f.partitionSet { 88 return sql.WindowInterval{}, io.EOF 89 } 90 if f.idx == 0 || (0 < f.idx && f.idx < f.partitionEnd) { 91 f.idx++ 92 return f.Interval() 93 } 94 return sql.WindowInterval{}, io.EOF 95 } 96 97 func (f *PartitionFramer) FirstIdx() int { 98 return f.frameStart 99 } 100 101 func (f *PartitionFramer) LastIdx() int { 102 return f.frameEnd 103 } 104 105 func (f *PartitionFramer) Interval() (sql.WindowInterval, error) { 106 if !f.partitionSet { 107 return sql.WindowInterval{}, ErrPartitionNotSet 108 } 109 return sql.WindowInterval{Start: f.frameStart, End: f.frameEnd}, nil 110 } 111 112 func (f *PartitionFramer) SlidingInterval(ctx sql.Context) (sql.WindowInterval, sql.WindowInterval, sql.WindowInterval) { 113 panic("implement me") 114 } 115 116 func (f *PartitionFramer) Close() { 
117 panic("implement me") 118 } 119 120 func NewGroupByFramer() *GroupByFramer { 121 return &GroupByFramer{ 122 frameEnd: -1, 123 frameStart: -1, 124 partitionStart: -1, 125 partitionEnd: -1, 126 } 127 } 128 129 type GroupByFramer struct { 130 evaluated bool 131 partitionStart, partitionEnd int 132 133 frameStart, frameEnd int 134 partitionSet bool 135 } 136 137 func (f *GroupByFramer) NewFramer(interval sql.WindowInterval) (sql.WindowFramer, error) { 138 return &GroupByFramer{ 139 evaluated: false, 140 frameEnd: interval.End, 141 frameStart: interval.Start, 142 partitionStart: interval.Start, 143 partitionEnd: interval.End, 144 partitionSet: true, 145 }, nil 146 } 147 148 func (f *GroupByFramer) Next(ctx *sql.Context, buffer sql.WindowBuffer) (sql.WindowInterval, error) { 149 if !f.partitionSet { 150 return sql.WindowInterval{}, io.EOF 151 } 152 if !f.evaluated { 153 f.evaluated = true 154 return f.Interval() 155 } 156 return sql.WindowInterval{}, io.EOF 157 } 158 159 func (f *GroupByFramer) FirstIdx() int { 160 return f.frameStart 161 } 162 163 func (f *GroupByFramer) LastIdx() int { 164 return f.frameEnd 165 } 166 167 func (f *GroupByFramer) Interval() (sql.WindowInterval, error) { 168 if !f.partitionSet { 169 return sql.WindowInterval{}, ErrPartitionNotSet 170 } 171 return sql.WindowInterval{Start: f.frameStart, End: f.frameEnd}, nil 172 } 173 174 func (f *GroupByFramer) SlidingInterval(ctx sql.Context) (sql.WindowInterval, sql.WindowInterval, sql.WindowInterval) { 175 panic("implement me") 176 } 177 178 // rowFramerBase is a sql.WindowFramer iterator that tracks 179 // index frames in a sql.WindowBuffer using integer offsets. 180 // Only a subset of bound conditions will be set for a given 181 // framer implementation, one start and one end bound. 
182 // 183 // Ex: startCurrentRow = true; endNFollowing = 1; 184 // buffer = [0, 1, 2, 3, 4, 5]; 185 // => 186 // pos: 0->0 1->1 2->2 3->3 4->4 5->5 187 // frame: {0,2}, {1,3}, {2,4}, {3,5}, {4,6}, {4,5} 188 // rows: [0,1], [1,2], [2,3], [3,4], [4,5], [5] 189 type rowFramerBase struct { 190 idx int 191 partitionStart int 192 partitionEnd int 193 frameStart int 194 frameEnd int 195 partitionSet bool 196 197 // add [startOffset] to current [idx] to find start index 198 // is set unless [unboundedPreceding] is true 199 startOffset int 200 // add [endOffset] to current [idx] to find end index 201 // is set unless [unboundedFollowing] is true 202 endOffset int 203 204 // optional start fields; one is set 205 startCurrentRow bool 206 startNPreceding int 207 startNFollowing int 208 unboundedPreceding bool 209 210 // optional end fields; one is set 211 endCurrentRow bool 212 unboundedFollowing bool 213 endNPreceding int 214 endNFollowing int 215 } 216 217 func (f *rowFramerBase) NewFramer(interval sql.WindowInterval) (sql.WindowFramer, error) { 218 var startOffset int 219 switch { 220 case f.startNPreceding != 0: 221 startOffset = -f.startNPreceding 222 case f.startNFollowing != 0: 223 startOffset = f.startNFollowing 224 case f.startCurrentRow: 225 startOffset = 0 226 } 227 228 var endOffset int 229 switch { 230 case f.endNPreceding != 0: 231 endOffset = -f.endNPreceding 232 case f.endNFollowing != 0: 233 endOffset = f.endNFollowing 234 case f.endCurrentRow: 235 endOffset = 0 236 } 237 238 return &rowFramerBase{ 239 idx: interval.Start, 240 partitionStart: interval.Start, 241 partitionEnd: interval.End, 242 frameStart: -1, 243 frameEnd: -1, 244 partitionSet: true, 245 // pass through parent state 246 unboundedPreceding: f.unboundedPreceding, 247 unboundedFollowing: f.unboundedFollowing, 248 startCurrentRow: f.startCurrentRow, 249 endCurrentRow: f.endCurrentRow, 250 startNPreceding: f.startNPreceding, 251 startNFollowing: f.startNFollowing, 252 endNPreceding: 
f.endNPreceding, 253 endNFollowing: f.endNFollowing, 254 // row specific 255 startOffset: startOffset, 256 endOffset: endOffset, 257 }, nil 258 } 259 260 func (f *rowFramerBase) Next(ctx *sql.Context, buffer sql.WindowBuffer) (sql.WindowInterval, error) { 261 if f.idx != 0 && f.idx >= f.partitionEnd || !f.partitionSet { 262 return sql.WindowInterval{}, io.EOF 263 } 264 265 if f.partitionEnd == 0 { 266 return sql.WindowInterval{}, io.EOF 267 } 268 269 newStart := f.idx + f.startOffset 270 if f.unboundedPreceding || newStart < f.partitionStart { 271 newStart = f.partitionStart 272 } 273 274 newEnd := f.idx + f.endOffset + 1 275 if f.unboundedFollowing || newEnd > f.partitionEnd { 276 newEnd = f.partitionEnd 277 } 278 279 if newStart > newEnd { 280 newStart = newEnd 281 } 282 283 f.frameStart = newStart 284 f.frameEnd = newEnd 285 286 f.idx++ 287 return f.Interval() 288 } 289 290 func (f *rowFramerBase) FirstIdx() int { 291 return f.frameStart 292 } 293 294 func (f *rowFramerBase) LastIdx() int { 295 return f.frameEnd 296 } 297 298 func (f *rowFramerBase) Interval() (sql.WindowInterval, error) { 299 if !f.partitionSet { 300 return sql.WindowInterval{}, ErrPartitionNotSet 301 } 302 return sql.WindowInterval{Start: f.frameStart, End: f.frameEnd}, nil 303 } 304 305 // rangeFramerBase is a sql.WindowFramer iterator that tracks 306 // value ranges in a sql.WindowBuffer using bound 307 // conditions on the order by [orderBy] column. Only a subset of 308 // bound conditions will be set for a given framer implementation, 309 // one start and one end bound. 
//
// Ex: startCurrentRow = true; endNFollowing = 2; orderBy = x;
// -> startInclusion = (x), endInclusion = (x+2)
// buffer = [0, 1, 2, 4, 4, 5];
// =>
// pos:   0->0   1->1   2->2   3->4   4->4   5->5
// frame: {0,3}, {1,3}, {2,5}, {3,6}, {3,6}, {5,6}
// rows:  [0,1,2], [1,2], [2,4,4], [4,4,5], [4,4,5], [5]
//
// Frames are half-open intervals (End is exclusive). Each bound only moves
// forward, so Next resumes its boundary search from the previous frame.
//
// NOTE(review): the bound arithmetic (x - N for PRECEDING, x + N for
// FOLLOWING) and the forward search assume the buffer is sorted ascending on
// [orderBy]; confirm how a descending ORDER BY reaches this framer.
type rangeFramerBase struct {
	idx                          int
	partitionStart, partitionEnd int
	frameStart, frameEnd         int
	partitionSet                 bool

	// reference expression for boundary calculation
	orderBy sql.Expression

	// boundary arithmetic on [orderBy] for range start value;
	// nil when [unboundedPreceding] is true
	startInclusion sql.Expression
	// boundary arithmetic on [orderBy] for range end value;
	// nil when [unboundedFollowing] is true
	endInclusion sql.Expression

	// optional start fields; one is set
	startCurrentRow    bool
	unboundedPreceding bool
	startNPreceding    sql.Expression
	startNFollowing    sql.Expression

	// optional end fields; one is set
	endCurrentRow      bool
	unboundedFollowing bool
	endNPreceding      sql.Expression
	endNFollowing      sql.Expression
}

// NewFramer returns a rangeFramerBase bound to [interval]. It builds each
// boundary expression once: CURRENT ROW uses [orderBy] itself, N PRECEDING
// uses orderBy - N, and N FOLLOWING uses orderBy + N. A bound with none of
// these set (unbounded) leaves its inclusion expression nil.
func (f *rangeFramerBase) NewFramer(interval sql.WindowInterval) (sql.WindowFramer, error) {
	var startInclusion sql.Expression
	switch {
	case f.startCurrentRow:
		startInclusion = f.orderBy
	case f.startNPreceding != nil:
		startInclusion = expression.NewArithmetic(f.orderBy, f.startNPreceding, ast.MinusStr)
	case f.startNFollowing != nil:
		startInclusion = expression.NewArithmetic(f.orderBy, f.startNFollowing, ast.PlusStr)
	}

	// TODO: how to validate datetime, interval pair when they aren't type comparable
	//if startInclusion != nil && startInclusion.Type().Promote() != f.orderBy.Type().Promote() {
	//	return nil, ErrRangeIntervalTypeMismatch
	//}

	var endInclusion sql.Expression
	switch {
	case f.endCurrentRow:
		endInclusion = f.orderBy
	case f.endNPreceding != nil:
		endInclusion = expression.NewArithmetic(f.orderBy, f.endNPreceding, ast.MinusStr)
	case f.endNFollowing != nil:
		endInclusion = expression.NewArithmetic(f.orderBy, f.endNFollowing, ast.PlusStr)
	}

	// TODO: how to validate datetime, interval pair when they aren't type comparable
	//if endInclusion != nil && endInclusion.Type().Promote() != f.orderBy.Type().Promote() {
	//	return nil, ErrRangeIntervalTypeMismatch
	//}

	return &rangeFramerBase{
		idx:            interval.Start,
		partitionStart: interval.Start,
		partitionEnd:   interval.End,
		frameStart:     interval.Start,
		frameEnd:       interval.Start,
		partitionSet:   true,
		// pass through parent state
		unboundedPreceding: f.unboundedPreceding,
		unboundedFollowing: f.unboundedFollowing,
		startCurrentRow:    f.startCurrentRow,
		endCurrentRow:      f.endCurrentRow,
		startNPreceding:    f.startNPreceding,
		startNFollowing:    f.startNFollowing,
		endNPreceding:      f.endNPreceding,
		endNFollowing:      f.endNFollowing,
		// range specific
		orderBy:        f.orderBy,
		startInclusion: startInclusion,
		endInclusion:   endInclusion,
	}, nil
}

// Next returns the frame of the current row and advances to the next row.
// The frame start is the first row whose [orderBy] value is >= the start
// inclusion value of the current row. The (exclusive) frame end is the first
// row whose [orderBy] value is > the end inclusion value of the current row.
// Unbounded bounds, and CURRENT ROW bounds without an order by, extend to the
// partition edges. Returns io.EOF after the last row, or if the framer was
// never bound to a partition.
//
// NOTE(review): unlike rowFramerBase.Next there is no partitionEnd == 0
// guard, so with idx == 0 a bounded frame still evaluates buf[0]; confirm
// callers never pass an empty partition starting at index 0.
func (f *rangeFramerBase) Next(ctx *sql.Context, buf sql.WindowBuffer) (sql.WindowInterval, error) {
	if f.idx != 0 && f.idx >= f.partitionEnd || !f.partitionSet {
		return sql.WindowInterval{}, io.EOF
	}

	var err error
	// Resume the start search from the previous frame start.
	newStart := f.frameStart
	switch {
	case newStart < f.partitionStart, f.unboundedPreceding, f.startCurrentRow && f.orderBy == nil:
		// Start the frame at the partition start for unbounded preceding, or for current row when no order by clause
		// has been specified. From the MySQL docs, when an order by clause is not specified with range framing, the
		// default frame includes all rows, since all rows in the current partition are peers when no order has been
		// specified.
		newStart = f.partitionStart
	default:
		newStart, err = findInclusionBoundary(ctx, f.idx, newStart, f.partitionEnd, f.startInclusion, f.orderBy, buf, greaterThanOrEqual)
		if err != nil {
			return sql.WindowInterval{}, err
		}
	}

	// Resume the end search from the previous frame end, but never before
	// the new start.
	newEnd := f.frameEnd
	if newStart > newEnd {
		newEnd = newStart
	}
	switch {
	case newEnd > f.partitionEnd, f.unboundedFollowing, f.endCurrentRow && f.orderBy == nil:
		newEnd = f.partitionEnd
	default:
		newEnd, err = findInclusionBoundary(ctx, f.idx, newEnd, f.partitionEnd, f.endInclusion, f.orderBy, buf, greaterThan)
		if err != nil {
			return sql.WindowInterval{}, err
		}
	}

	f.idx++
	f.frameStart = newStart
	f.frameEnd = newEnd
	return f.Interval()
}

// stopCond is the minimum comparison result, compare(candidate, bound), at
// which findInclusionBoundary stops scanning.
type stopCond int

const (
	// unknown is the initial comparison sentinel; it is below every stop
	// condition, so the search always evaluates at least one candidate.
	unknown = -2
	// greaterThan stops at the first candidate strictly greater than the bound.
	greaterThan = 1
	// greaterThanOrEqual stops at the first candidate >= the bound.
	greaterThanOrEqual = 0
)

// findInclusionBoundary scans [buf] forward from [searchStart] and returns
// the first index i < [partitionEnd] for which
// expr.Type().Compare(expr(buf[i]), inclusion(buf[pos])) >= [stopCond].
// That is the first row with expr >= inclusion for greaterThanOrEqual, or
// expr > inclusion for greaterThan. [inclusion] is evaluated once on the
// current row [pos]; [expr] is evaluated on each boundary candidate. If the
// scan reaches [partitionEnd] without a match, that index is returned.
// Because frame bounds only move forward, callers pass the previous bound as
// [searchStart], making this a sliding window algorithm for value ranges.
454 func findInclusionBoundary(ctx *sql.Context, pos, searchStart, partitionEnd int, inclusion, expr sql.Expression, buf sql.WindowBuffer, stopCond stopCond) (int, error) { 455 cur, err := inclusion.Eval(ctx, buf[pos]) 456 if err != nil { 457 return 0, err 458 } 459 460 i := searchStart 461 cmp := unknown 462 for ; cmp < int(stopCond); i++ { 463 if i >= partitionEnd { 464 return i, nil 465 } 466 467 res, err := expr.Eval(ctx, buf[i]) 468 if err != nil { 469 return 0, err 470 } 471 472 cmp, err = expr.Type().Compare(res, cur) 473 if err != nil { 474 return 0, err 475 } 476 } 477 478 return i - 1, nil 479 } 480 481 func (f *rangeFramerBase) FirstIdx() int { 482 return f.frameStart 483 } 484 485 func (f *rangeFramerBase) LastIdx() int { 486 return f.frameEnd 487 } 488 489 func (f *rangeFramerBase) Interval() (sql.WindowInterval, error) { 490 if !f.partitionSet { 491 return sql.WindowInterval{}, ErrPartitionNotSet 492 } 493 return sql.WindowInterval{Start: f.frameStart, End: f.frameEnd}, nil 494 } 495 496 type PeerGroupFramer struct { 497 idx int 498 partitionStart, partitionEnd int 499 frameStart, frameEnd int 500 partitionSet bool 501 502 // reference for peer calculation 503 orderBy []sql.Expression 504 } 505 506 func NewPeerGroupFramer(orderBy []sql.Expression) *PeerGroupFramer { 507 return &PeerGroupFramer{ 508 frameEnd: -1, 509 frameStart: -1, 510 partitionStart: -1, 511 partitionEnd: -1, 512 orderBy: orderBy, 513 } 514 } 515 516 func (f *PeerGroupFramer) NewFramer(interval sql.WindowInterval) (sql.WindowFramer, error) { 517 return &PeerGroupFramer{ 518 idx: interval.Start, 519 partitionStart: interval.Start, 520 partitionEnd: interval.End, 521 frameStart: interval.Start, 522 frameEnd: interval.Start, 523 partitionSet: true, 524 orderBy: f.orderBy, 525 }, nil 526 } 527 528 func (f *PeerGroupFramer) Next(ctx *sql.Context, buf sql.WindowBuffer) (sql.WindowInterval, error) { 529 if f.idx != 0 && f.idx >= f.partitionEnd || !f.partitionSet { 530 return 
sql.WindowInterval{}, io.EOF 531 } 532 if f.idx >= f.frameEnd { 533 peerGroup, err := nextPeerGroup(ctx, f.idx, f.partitionEnd, f.orderBy, buf) 534 if err != nil { 535 return sql.WindowInterval{}, err 536 } 537 f.frameStart = peerGroup.Start 538 f.frameEnd = peerGroup.End 539 } 540 f.idx++ 541 return f.Interval() 542 } 543 544 func (f *PeerGroupFramer) FirstIdx() int { 545 return f.frameStart 546 } 547 548 func (f *PeerGroupFramer) LastIdx() int { 549 return f.frameEnd 550 } 551 552 func (f *PeerGroupFramer) Interval() (sql.WindowInterval, error) { 553 if !f.partitionSet { 554 return sql.WindowInterval{}, ErrPartitionNotSet 555 } 556 return sql.WindowInterval{Start: f.frameStart, End: f.frameEnd}, nil 557 } 558 559 // nextPeerGroup scans for a sql.WindowInterval of rows with the same value as 560 // the current row [a.pos]. This is equivalent to a partitioning algorithm, but 561 // we are using the OrderBy fields, and we stream the results. 562 // ex: [1, 2, 2, 2, 2, 3, 3, 4, 5, 5, 6] => {0,1}, {1,5}, {5,7}, {8,9}, {9,10} 563 func nextPeerGroup(ctx *sql.Context, pos, partitionEnd int, orderBy []sql.Expression, buffer sql.WindowBuffer) (sql.WindowInterval, error) { 564 if pos >= partitionEnd || pos > len(buffer) { 565 return sql.WindowInterval{}, nil 566 } 567 var row sql.Row 568 i := pos + 1 569 last := buffer[pos] 570 for i < partitionEnd { 571 row = buffer[i] 572 if newPeerGroup, err := isNewOrderByValue(ctx, orderBy, last, row); err != nil { 573 return sql.WindowInterval{}, err 574 } else if newPeerGroup { 575 break 576 } 577 i++ 578 last = row 579 } 580 return sql.WindowInterval{Start: pos, End: i}, nil 581 } 582 583 // isNewOrderByValue compares the order by columns between two rows, returning true when the last row is null or 584 // when the next row's orderBy columns are unique 585 func isNewOrderByValue(ctx *sql.Context, orderByExprs []sql.Expression, last sql.Row, row sql.Row) (bool, error) { 586 if len(last) == 0 { 587 return true, nil 588 } 589 590 
lastExp, _, err := evalExprs(ctx, orderByExprs, last) 591 if err != nil { 592 return false, err 593 } 594 595 thisExp, _, err := evalExprs(ctx, orderByExprs, row) 596 if err != nil { 597 return false, err 598 } 599 600 for i := range lastExp { 601 if lastExp[i] != thisExp[i] { 602 return true, nil 603 } 604 } 605 606 return false, nil 607 }