github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/rel_iters.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rowexec 16 17 import ( 18 "container/heap" 19 "errors" 20 "fmt" 21 "io" 22 "sort" 23 "strings" 24 25 "github.com/dolthub/jsonpath" 26 27 "github.com/dolthub/go-mysql-server/sql" 28 "github.com/dolthub/go-mysql-server/sql/expression" 29 "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation" 30 "github.com/dolthub/go-mysql-server/sql/plan" 31 "github.com/dolthub/go-mysql-server/sql/types" 32 ) 33 34 type topRowsIter struct { 35 sortFields sql.SortFields 36 calcFoundRows bool 37 childIter sql.RowIter 38 limit int64 39 topRows []sql.Row 40 numFoundRows int64 41 idx int 42 } 43 44 func newTopRowsIter(s sql.SortFields, limit int64, calcFoundRows bool, child sql.RowIter, childSchemaLen int) *topRowsIter { 45 return &topRowsIter{ 46 sortFields: append(s, sql.SortField{Column: expression.NewGetField(childSchemaLen, types.Int64, "order", false)}), 47 limit: limit, 48 calcFoundRows: calcFoundRows, 49 childIter: child, 50 idx: -1, 51 } 52 } 53 54 func (i *topRowsIter) Next(ctx *sql.Context) (sql.Row, error) { 55 if i.idx == -1 { 56 err := i.computeTopRows(ctx) 57 if err != nil { 58 return nil, err 59 } 60 i.idx = 0 61 } 62 63 if i.idx >= len(i.topRows) { 64 return nil, io.EOF 65 } 66 row := i.topRows[i.idx] 67 i.idx++ 68 return row[:len(row)-1], nil 69 } 70 71 func (i *topRowsIter) Close(ctx *sql.Context) error { 72 i.topRows = nil 73 74 if i.calcFoundRows { 75 ctx.SetLastQueryInfo(sql.FoundRows, i.numFoundRows) 76 } 77 78 return i.childIter.Close(ctx) 79 } 80 81 func (i *topRowsIter) computeTopRows(ctx *sql.Context) error { 82 topRowsHeap := &expression.TopRowsHeap{ 83 expression.Sorter{ 84 SortFields: i.sortFields, 85 Rows: []sql.Row{}, 86 LastError: nil, 87 Ctx: ctx, 88 }, 89 } 90 for { 91 row, err := i.childIter.Next(ctx) 92 if err == io.EOF { 93 break 94 } 95 if err != nil { 96 return err 97 } 98 i.numFoundRows++ 99 100 row = append(row, i.numFoundRows) 101 102 heap.Push(topRowsHeap, row) 103 if int64(topRowsHeap.Len()) > i.limit { 104 heap.Pop(topRowsHeap) 105 } 106 if topRowsHeap.LastError != nil { 107 return topRowsHeap.LastError 108 } 109 } 110 111 var err error 112 i.topRows, err = topRowsHeap.Rows() 113 return err 114 } 115 116 // getInt64Value returns the int64 literal value in the expression given, or an error with the errStr given if it 117 // cannot. 118 func getInt64Value(ctx *sql.Context, expr sql.Expression) (int64, error) { 119 i, err := expr.Eval(ctx, nil) 120 if err != nil { 121 return 0, err 122 } 123 124 switch i := i.(type) { 125 case int: 126 return int64(i), nil 127 case int8: 128 return int64(i), nil 129 case int16: 130 return int64(i), nil 131 case int32: 132 return int64(i), nil 133 case int64: 134 return i, nil 135 case uint: 136 return int64(i), nil 137 case uint8: 138 return int64(i), nil 139 case uint16: 140 return int64(i), nil 141 case uint32: 142 return int64(i), nil 143 case uint64: 144 return int64(i), nil 145 default: 146 // analyzer should catch this already 147 panic(fmt.Sprintf("Unsupported type for limit %T", i)) 148 } 149 } 150 151 // windowToIter transforms a plan.Window into a series 152 // of aggregation.WindowPartitionIter and a list of output projection indexes 153 // for each window partition. 154 // TODO: make partition ordering deterministic 155 func windowToIter(w *plan.Window) ([]*aggregation.WindowPartitionIter, [][]int, error) { 156 partIdToOutputIdxs := make(map[uint64][]int, 0) 157 partIdToBlock := make(map[uint64]*aggregation.WindowPartition, 0) 158 var window *sql.WindowDefinition 159 var agg *aggregation.Aggregation 160 var fn sql.WindowFunction 161 var err error 162 // collect functions in hash map keyed by partitioning scheme 163 for i, expr := range w.SelectExprs { 164 if alias, ok := expr.(*expression.Alias); ok { 165 expr = alias.Child 166 } 167 switch e := expr.(type) { 168 case sql.Aggregation: 169 window = e.Window() 170 fn, err = e.NewWindowFunction() 171 case sql.WindowAggregation: 172 window = e.Window() 173 fn, err = e.NewWindowFunction() 174 default: 175 // non window aggregates resolve to LastAgg with empty over clause 176 window = sql.NewWindowDefinition(nil, nil, nil, "", "") 177 fn, err = aggregation.NewLast(e).NewWindowFunction() 178 } 179 if err != nil { 180 return nil, nil, err 181 } 182 agg = aggregation.NewAggregation(fn, fn.DefaultFramer()) 183 184 id, err := window.PartitionId() 185 if err != nil { 186 return nil, nil, err 187 } 188 189 if block, ok := partIdToBlock[id]; !ok { 190 if err != nil { 191 return nil, nil, err 192 } 193 partIdToBlock[id] = aggregation.NewWindowPartition( 194 window.PartitionBy, 195 window.OrderBy, 196 []*aggregation.Aggregation{agg}, 197 ) 198 partIdToOutputIdxs[id] = []int{i} 199 } else { 200 block.AddAggregation(agg) 201 partIdToOutputIdxs[id] = append(partIdToOutputIdxs[id], i) 202 } 203 } 204 205 // convert partition hash map into list 206 blockIters := make([]*aggregation.WindowPartitionIter, len(partIdToBlock)) 207 outputOrdinals := make([][]int, len(partIdToBlock)) 208 i := 0 209 for id, block := range partIdToBlock { 210 outputIdx := partIdToOutputIdxs[id] 211 blockIters[i] = aggregation.NewWindowPartitionIter(block) 212 outputOrdinals[i] = outputIdx 213 i++ 214 } 215 return blockIters, outputOrdinals, nil 216 } 217 218 type offsetIter struct { 219 skip int64 220 childIter sql.RowIter 221 } 222 223 func (i *offsetIter) Next(ctx *sql.Context) (sql.Row, error) { 224 if i.skip > 0 { 225 for i.skip > 0 { 226 _, err := i.childIter.Next(ctx) 227 if err != nil { 228 return nil, err 229 } 230 i.skip-- 231 } 232 } 233 234 row, err := i.childIter.Next(ctx) 235 if err != nil { 236 return nil, err 237 } 238 239 return row, nil 240 } 241 242 func (i *offsetIter) Close(ctx *sql.Context) error { 243 return i.childIter.Close(ctx) 244 } 245 246 type jsonTableColOpts struct { 247 name string 248 typ sql.Type 249 forOrd bool 250 exists bool 251 defErrVal interface{} 252 defEmpVal interface{} 253 errOnErr bool 254 errOnEmp bool 255 } 256 257 // jsonTableCol represents a column in a json table. 258 type jsonTableCol struct { 259 path string // if there are nested columns, this is a schema path, otherwise it is a col path 260 opts *jsonTableColOpts 261 cols []*jsonTableCol // nested columns 262 263 data []interface{} 264 err error 265 pos int 266 finished bool // exhausted all rows in data 267 currSib int 268 } 269 270 // IsSibling returns if the jsonTableCol contains multiple columns 271 func (c *jsonTableCol) IsSibling() bool { 272 return len(c.cols) != 0 273 } 274 275 // NextSibling starts at the current sibling and moves to the next unfinished sibling 276 // if there are no more unfinished siblings, it sets c.currSib to the first sibling and returns true 277 // if the c.currSib is unfinished, nothing changes 278 func (c *jsonTableCol) NextSibling() bool { 279 for i := c.currSib; i < len(c.cols); i++ { 280 if c.cols[i].IsSibling() && !c.cols[i].finished { 281 c.currSib = i 282 return false 283 } 284 } 285 c.currSib = 0 286 for i := 0; i < len(c.cols); i++ { 287 if c.cols[i].IsSibling() { 288 c.currSib = i 289 break 290 } 291 } 292 return true 293 } 294 295 // LoadData loads the data for this column from the given object and c.path 296 // LoadData will always wrap the data in a slice to ensure it is iterable 297 // Additionally, this function will set the c.currSib to the first sibling 298 func (c *jsonTableCol) LoadData(obj interface{}) { 299 var data interface{} 300 data, c.err = jsonpath.JsonPathLookup(obj, c.path) 301 if d, ok := data.([]interface{}); ok { 302 c.data = d 303 } else { 304 c.data = []interface{}{data} 305 } 306 c.pos = 0 307 308 c.NextSibling() 309 } 310 311 // Reset clears the column's data and error, and recursively resets all nested columns 312 func (c *jsonTableCol) Reset() { 313 c.data, c.err = nil, nil 314 c.finished = false 315 for _, col := range c.cols { 316 col.Reset() 317 } 318 } 319 320 // Next returns the next row for this column. 321 func (c *jsonTableCol) Next(obj interface{}, pass bool, ord int) (sql.Row, error) { 322 // nested column should recurse 323 if len(c.cols) != 0 { 324 if c.data == nil { 325 c.LoadData(obj) 326 } 327 328 var innerObj interface{} 329 if !c.finished { 330 innerObj = c.data[c.pos] 331 } 332 333 var row sql.Row 334 for i, col := range c.cols { 335 innerPass := len(col.cols) != 0 && i != c.currSib 336 rowPart, err := col.Next(innerObj, pass || innerPass, c.pos+1) 337 if err != nil { 338 return nil, err 339 } 340 row = append(row, rowPart...) 341 } 342 343 if pass { 344 return row, nil 345 } 346 347 if c.NextSibling() { 348 for _, col := range c.cols { 349 col.Reset() 350 } 351 c.pos++ 352 } 353 354 if c.pos >= len(c.data) { 355 c.finished = true 356 } 357 358 return row, nil 359 } 360 361 // this should only apply to nested columns, maybe... 362 if pass { 363 return sql.Row{nil}, nil 364 } 365 366 // FOR ORDINAL is a special case 367 if c.opts != nil && c.opts.forOrd { 368 return sql.Row{ord}, nil 369 } 370 371 // TODO: cache this? 372 val, err := jsonpath.JsonPathLookup(obj, c.path) 373 if c.opts.exists { 374 if err != nil { 375 return sql.Row{0}, nil 376 } else { 377 return sql.Row{1}, nil 378 } 379 } 380 381 // key error means empty 382 if err != nil { 383 if c.opts.errOnEmp { 384 return nil, fmt.Errorf("missing value for JSON_TABLE column '%s'", c.opts.name) 385 } 386 val = c.opts.defEmpVal 387 } 388 389 val, _, err = c.opts.typ.Convert(val) 390 if err != nil { 391 if c.opts.errOnErr { 392 return nil, err 393 } 394 val, _, err = c.opts.typ.Convert(c.opts.defErrVal) 395 if err != nil { 396 return nil, err 397 } 398 } 399 400 // Base columns are always finished 401 c.finished = true 402 return sql.Row{val}, nil 403 } 404 405 type jsonTableRowIter struct { 406 data []interface{} 407 pos int 408 cols []*jsonTableCol 409 currSib int 410 } 411 412 var _ sql.RowIter = &jsonTableRowIter{} 413 414 // NextSibling starts at the current sibling and moves to the next unfinished sibling 415 // if there are no more unfinished siblings, it resets to the first sibling 416 func (j *jsonTableRowIter) NextSibling() bool { 417 for i := j.currSib; i < len(j.cols); i++ { 418 if !j.cols[i].finished && len(j.cols[i].cols) != 0 { 419 j.currSib = i 420 return false 421 } 422 } 423 j.currSib = 0 424 for i := 0; i < len(j.cols); i++ { 425 if len(j.cols[i].cols) != 0 { 426 j.currSib = i 427 break 428 } 429 } 430 return true 431 } 432 433 func (j *jsonTableRowIter) ResetAll() { 434 for _, col := range j.cols { 435 col.Reset() 436 } 437 } 438 439 func (j *jsonTableRowIter) Next(ctx *sql.Context) (sql.Row, error) { 440 if j.pos >= len(j.data) { 441 return nil, io.EOF 442 } 443 obj := j.data[j.pos] 444 445 var row sql.Row 446 for i, col := range j.cols { 447 pass := len(col.cols) != 0 && i != j.currSib 448 rowPart, err := col.Next(obj, pass, j.pos+1) 449 if err != nil { 450 return nil, err 451 } 452 row = append(row, rowPart...) 453 } 454 455 if j.NextSibling() { 456 j.ResetAll() 457 j.pos++ 458 } 459 460 return row, nil 461 } 462 463 func (j *jsonTableRowIter) Close(ctx *sql.Context) error { 464 return nil 465 } 466 467 // orderedDistinctIter iterates the children iterator and skips all the 468 // repeated rows assuming the iterator has all rows sorted. 469 type orderedDistinctIter struct { 470 childIter sql.RowIter 471 schema sql.Schema 472 prevRow sql.Row 473 } 474 475 func newOrderedDistinctIter(child sql.RowIter, schema sql.Schema) *orderedDistinctIter { 476 return &orderedDistinctIter{childIter: child, schema: schema} 477 } 478 479 func (di *orderedDistinctIter) Next(ctx *sql.Context) (sql.Row, error) { 480 for { 481 row, err := di.childIter.Next(ctx) 482 if err != nil { 483 return nil, err 484 } 485 486 if di.prevRow != nil { 487 ok, err := di.prevRow.Equals(row, di.schema) 488 if err != nil { 489 return nil, err 490 } 491 492 if ok { 493 continue 494 } 495 } 496 497 di.prevRow = row 498 return row, nil 499 } 500 } 501 502 func (di *orderedDistinctIter) Close(ctx *sql.Context) error { 503 return di.childIter.Close(ctx) 504 } 505 506 type projectIter struct { 507 p []sql.Expression 508 childIter sql.RowIter 509 } 510 511 func (i *projectIter) Next(ctx *sql.Context) (sql.Row, error) { 512 childRow, err := i.childIter.Next(ctx) 513 if err != nil { 514 return nil, err 515 } 516 517 return ProjectRow(ctx, i.p, childRow) 518 } 519 520 func (i *projectIter) Close(ctx *sql.Context) error { 521 return i.childIter.Close(ctx) 522 } 523 524 // ProjectRow evaluates a set of projections. 525 func ProjectRow( 526 ctx *sql.Context, 527 projections []sql.Expression, 528 row sql.Row, 529 ) (sql.Row, error) { 530 var secondPass []int 531 var fields sql.Row 532 for i, expr := range projections { 533 // Default values that are expressions may reference other fields, thus they must evaluate after all other exprs. 534 // Also default expressions may not refer to other columns that come after them if they also have a default expr. 535 // This ensures that all columns referenced by expressions will have already been evaluated. 536 // Since literals do not reference other columns, they're evaluated on the first pass. 537 defaultVal, isDefaultVal := defaultValFromProjectExpr(expr) 538 if isDefaultVal && !defaultVal.IsLiteral() { 539 fields = append(fields, nil) 540 secondPass = append(secondPass, i) 541 continue 542 } 543 f, fErr := expr.Eval(ctx, row) 544 if fErr != nil { 545 return nil, fErr 546 } 547 f = normalizeNegativeZeros(f) 548 fields = append(fields, f) 549 } 550 for _, index := range secondPass { 551 field, err := projections[index].Eval(ctx, fields) 552 if err != nil { 553 return nil, err 554 } 555 field = normalizeNegativeZeros(field) 556 fields[index] = field 557 } 558 return sql.NewRow(fields...), nil 559 } 560 561 func defaultValFromProjectExpr(e sql.Expression) (*sql.ColumnDefaultValue, bool) { 562 if defaultVal, ok := e.(*expression.Wrapper); ok { 563 e = defaultVal.Unwrap() 564 } 565 if defaultVal, ok := e.(*sql.ColumnDefaultValue); ok { 566 return defaultVal, true 567 } 568 569 return nil, false 570 } 571 572 func defaultValFromSetExpression(e sql.Expression) (*sql.ColumnDefaultValue, bool) { 573 if sf, ok := e.(*expression.SetField); ok { 574 return defaultValFromProjectExpr(sf.RightChild) 575 } 576 return nil, false 577 } 578 579 // normalizeNegativeZeros converts negative zero into positive zero. 580 // We do this so that floats and decimals have the same representation when displayed to the user. 581 func normalizeNegativeZeros(val interface{}) interface{} { 582 // Golang doesn't have a negative zero literal, but negative zero compares equal to zero. 583 if val == float32(0) { 584 return float32(0) 585 } 586 if val == float64(0) { 587 return float64(0) 588 } 589 return val 590 } 591 592 // TODO a queue is probably more optimal 593 type recursiveTableIter struct { 594 pos int 595 buf []sql.Row 596 } 597 598 var _ sql.RowIter = (*recursiveTableIter)(nil) 599 600 func (r *recursiveTableIter) Next(ctx *sql.Context) (sql.Row, error) { 601 if r.buf == nil || r.pos >= len(r.buf) { 602 return nil, io.EOF 603 } 604 r.pos++ 605 return r.buf[r.pos-1], nil 606 } 607 608 func (r *recursiveTableIter) Close(ctx *sql.Context) error { 609 r.buf = nil 610 return nil 611 } 612 613 func setUserVar(ctx *sql.Context, userVar *expression.UserVar, right sql.Expression, row sql.Row) error { 614 val, err := right.Eval(ctx, row) 615 if err != nil { 616 return err 617 } 618 typ := types.ApproximateTypeFromValue(val) 619 620 err = ctx.SetUserVariable(ctx, userVar.Name, val, typ) 621 if err != nil { 622 return err 623 } 624 return nil 625 } 626 627 func setSystemVar(ctx *sql.Context, sysVar *expression.SystemVar, right sql.Expression, row sql.Row) error { 628 val, err := right.Eval(ctx, row) 629 if err != nil { 630 return err 631 } 632 switch sysVar.Scope { 633 case sql.SystemVariableScope_Global: 634 err = sql.SystemVariables.SetGlobal(sysVar.Name, val) 635 if err != nil { 636 return err 637 } 638 case sql.SystemVariableScope_Session: 639 err = ctx.SetSessionVariable(ctx, sysVar.Name, val) 640 if err != nil { 641 return err 642 } 643 case sql.SystemVariableScope_Persist: 644 persistSess, ok := ctx.Session.(sql.PersistableSession) 645 if !ok { 646 return sql.ErrSessionDoesNotSupportPersistence.New() 647 } 648 err = persistSess.PersistGlobal(sysVar.Name, val) 649 if err != nil { 650 return err 651 } 652 err = sql.SystemVariables.SetGlobal(sysVar.Name, val) 653 if err != nil { 654 return err 655 } 656 case sql.SystemVariableScope_PersistOnly: 657 persistSess, ok := ctx.Session.(sql.PersistableSession) 658 if !ok { 659 return sql.ErrSessionDoesNotSupportPersistence.New() 660 } 661 err = persistSess.PersistGlobal(sysVar.Name, val) 662 if err != nil { 663 return err 664 } 665 case sql.SystemVariableScope_ResetPersist: 666 // TODO: add parser support for RESET PERSIST 667 persistSess, ok := ctx.Session.(sql.PersistableSession) 668 if !ok { 669 return sql.ErrSessionDoesNotSupportPersistence.New() 670 } 671 if sysVar.Name == "" { 672 err = persistSess.RemoveAllPersistedGlobals() 673 } 674 err = persistSess.RemovePersistedGlobal(sysVar.Name) 675 if err != nil { 676 return err 677 } 678 default: // should never be hit 679 return fmt.Errorf("unable to set `%s` due to unknown scope `%v`", sysVar.Name, sysVar.Scope) 680 } 681 // Setting `character_set_connection`, regardless of how it is set (directly or through SET NAMES) will also set 682 // `collation_connection` to the default collation for the given character set. 683 if strings.ToLower(sysVar.Name) == "character_set_connection" { 684 newSysVar := &expression.SystemVar{ 685 Name: "collation_connection", 686 Scope: sysVar.Scope, 687 } 688 if val == nil { 689 err = setSystemVar(ctx, newSysVar, expression.NewLiteral("", types.LongText), row) 690 if err != nil { 691 return err 692 } 693 } else { 694 valStr, ok := val.(string) 695 if !ok { 696 return sql.ErrInvalidSystemVariableValue.New("collation_connection", val) 697 } 698 charset, err := sql.ParseCharacterSet(valStr) 699 if err != nil { 700 return err 701 } 702 charset = charset 703 err = setSystemVar(ctx, newSysVar, expression.NewLiteral(charset.DefaultCollation().Name(), types.LongText), row) 704 if err != nil { 705 return err 706 } 707 } 708 } 709 return nil 710 } 711 712 // Applies the update expressions given to the row given, returning the new resultant row. 713 func applyUpdateExpressions(ctx *sql.Context, updateExprs []sql.Expression, row sql.Row) (sql.Row, error) { 714 var ok bool 715 prev := row 716 for _, updateExpr := range updateExprs { 717 val, err := updateExpr.Eval(ctx, prev) 718 if err != nil { 719 return nil, err 720 } 721 prev, ok = val.(sql.Row) 722 if !ok { 723 return nil, plan.ErrUpdateUnexpectedSetResult.New(val) 724 } 725 } 726 return prev, nil 727 } 728 729 // declareVariablesIter is the sql.RowIter of *DeclareVariables. 730 type declareVariablesIter struct { 731 *plan.DeclareVariables 732 row sql.Row 733 } 734 735 var _ sql.RowIter = (*declareVariablesIter)(nil) 736 737 // Next implements the interface sql.RowIter. 738 func (d *declareVariablesIter) Next(ctx *sql.Context) (sql.Row, error) { 739 defaultVal, err := d.DefaultVal.Eval(ctx, d.row) 740 if err != nil { 741 return nil, err 742 } 743 for _, varName := range d.Names { 744 if err := d.Pref.InitializeVariable(varName, d.Type, defaultVal); err != nil { 745 return nil, err 746 } 747 } 748 return nil, io.EOF 749 } 750 751 // Close implements the interface sql.RowIter. 752 func (d *declareVariablesIter) Close(ctx *sql.Context) error { 753 return nil 754 } 755 756 // declareHandlerIter is the sql.RowIter of *DeclareHandler. 757 type declareHandlerIter struct { 758 *plan.DeclareHandler 759 } 760 761 var _ sql.RowIter = (*declareHandlerIter)(nil) 762 763 // Next implements the interface sql.RowIter. 764 func (d *declareHandlerIter) Next(ctx *sql.Context) (sql.Row, error) { 765 d.Pref.InitializeHandler(d.Statement, d.Action, d.Condition) 766 return nil, io.EOF 767 } 768 769 // Close implements the interface sql.RowIter. 770 func (d *declareHandlerIter) Close(ctx *sql.Context) error { 771 return nil 772 } 773 774 const cteRecursionLimit = 10001 775 776 // recursiveCteIter exhaustively executes a recursive 777 // relation [rec] populated by an [init] base case. 778 // Refer to RecursiveCte for more details. 779 type recursiveCteIter struct { 780 // base sql.Project 781 init sql.Node 782 // recursive sql.Project 783 rec sql.Node 784 // anchor to recursive table to repopulate with [temp] 785 working *plan.RecursiveTable 786 // true if UNION, false if UNION ALL 787 deduplicate bool 788 // parent iter initialization state 789 row sql.Row 790 791 // active iterator, either [init].RowIter or [rec].RowIter 792 iter sql.RowIter 793 // number of recursive iterations finished 794 cycle int 795 // buffer to collect intermediate results for next recursion 796 temp []sql.Row 797 // duplicate lookup if [deduplicated] set 798 cache sql.KeyValueCache 799 b *BaseBuilder 800 } 801 802 var _ sql.RowIter = (*recursiveCteIter)(nil) 803 804 // Next implements sql.RowIter 805 func (r *recursiveCteIter) Next(ctx *sql.Context) (sql.Row, error) { 806 if r.iter == nil { 807 // start with [Init].RowIter 808 var err error 809 if r.deduplicate { 810 r.cache = sql.NewMapCache() 811 812 } 813 r.iter, err = r.b.buildNodeExec(ctx, r.init, r.row) 814 815 if err != nil { 816 return nil, err 817 } 818 } 819 820 var row sql.Row 821 for { 822 var err error 823 row, err = r.iter.Next(ctx) 824 if errors.Is(err, io.EOF) && len(r.temp) > 0 { 825 // reset [Rec].RowIter 826 err = r.resetIter(ctx) 827 if err != nil { 828 return nil, err 829 } 830 continue 831 } else if err != nil { 832 return nil, err 833 } 834 835 var key uint64 836 if r.deduplicate { 837 key, _ = sql.HashOf(row) 838 if k, _ := r.cache.Get(key); k != nil { 839 // skip duplicate 840 continue 841 } 842 } 843 r.store(row, key) 844 if err != nil { 845 return nil, err 846 } 847 break 848 } 849 return row, nil 850 } 851 852 // store saves a row to the [temp] buffer, and hashes if [deduplicated] = true 853 func (r *recursiveCteIter) store(row sql.Row, key uint64) { 854 if r.deduplicate { 855 r.cache.Put(key, struct{}{}) 856 } 857 r.temp = append(r.temp, row) 858 return 859 } 860 861 // resetIter creates a new [Rec].RowIter after refreshing the [working] RecursiveTable 862 func (r *recursiveCteIter) resetIter(ctx *sql.Context) error { 863 if len(r.temp) == 0 { 864 return io.EOF 865 } 866 r.cycle++ 867 if r.cycle > cteRecursionLimit { 868 return sql.ErrCteRecursionLimitExceeded.New() 869 } 870 871 if r.working != nil { 872 r.working.Buf = r.temp 873 r.temp = make([]sql.Row, 0) 874 } 875 876 err := r.iter.Close(ctx) 877 if err != nil { 878 return err 879 } 880 r.iter, err = r.b.buildNodeExec(ctx, r.rec, r.row) 881 if err != nil { 882 return err 883 } 884 return nil 885 } 886 887 // Close implements sql.RowIter 888 func (r *recursiveCteIter) Close(ctx *sql.Context) error { 889 r.working.Buf = nil 890 r.temp = nil 891 if r.iter != nil { 892 return r.iter.Close(ctx) 893 } 894 return nil 895 } 896 897 type limitIter struct { 898 calcFoundRows bool 899 currentPos int64 900 childIter sql.RowIter 901 limit int64 902 } 903 904 func (li *limitIter) Next(ctx *sql.Context) (sql.Row, error) { 905 if li.currentPos >= li.limit { 906 // If we were asked to calc all found rows, then when we are past the limit we iterate over the rest of the 907 // result set to count it 908 if li.calcFoundRows { 909 for { 910 _, err := li.childIter.Next(ctx) 911 if err != nil { 912 return nil, err 913 } 914 li.currentPos++ 915 } 916 } 917 918 return nil, io.EOF 919 } 920 921 childRow, err := li.childIter.Next(ctx) 922 if err != nil { 923 return nil, err 924 } 925 li.currentPos++ 926 927 return childRow, nil 928 } 929 930 func (li *limitIter) Close(ctx *sql.Context) error { 931 err := li.childIter.Close(ctx) 932 if err != nil { 933 return err 934 } 935 936 if li.calcFoundRows { 937 ctx.SetLastQueryInfo(sql.FoundRows, li.currentPos) 938 } 939 return nil 940 } 941 942 type sortIter struct { 943 sortFields sql.SortFields 944 childIter sql.RowIter 945 sortedRows []sql.Row 946 idx int 947 } 948 949 var _ sql.RowIter = (*sortIter)(nil) 950 951 func newSortIter(s sql.SortFields, child sql.RowIter) *sortIter { 952 return &sortIter{ 953 sortFields: s, 954 childIter: child, 955 idx: -1, 956 } 957 } 958 959 func (i *sortIter) Next(ctx *sql.Context) (sql.Row, error) { 960 if i.idx == -1 { 961 err := i.computeSortedRows(ctx) 962 if err != nil { 963 return nil, err 964 } 965 i.idx = 0 966 } 967 968 if i.idx >= len(i.sortedRows) { 969 return nil, io.EOF 970 } 971 row := i.sortedRows[i.idx] 972 i.idx++ 973 return row, nil 974 } 975 976 func (i *sortIter) Close(ctx *sql.Context) error { 977 i.sortedRows = nil 978 return i.childIter.Close(ctx) 979 } 980 981 func (i *sortIter) computeSortedRows(ctx *sql.Context) error { 982 cache, dispose := ctx.Memory.NewRowsCache() 983 defer dispose() 984 985 for { 986 row, err := i.childIter.Next(ctx) 987 988 if err == io.EOF { 989 break 990 } 991 if err != nil { 992 return err 993 } 994 995 if err := cache.Add(row); err != nil { 996 return err 997 } 998 } 999 1000 rows := cache.Get() 1001 sorter := &expression.Sorter{ 1002 SortFields: i.sortFields, 1003 Rows: rows, 1004 LastError: nil, 1005 Ctx: ctx, 1006 } 1007 sort.Stable(sorter) 1008 if sorter.LastError != nil { 1009 return sorter.LastError 1010 } 1011 i.sortedRows = rows 1012 return nil 1013 } 1014 1015 // distinctIter keeps track of the hashes of all rows that have been emitted. 1016 // It does not emit any rows whose hashes have been seen already. 1017 // TODO: come up with a way to use less memory than keeping all hashes in memory. 1018 // Even though they are just 64-bit integers, this could be a problem in large 1019 // result sets. 1020 type distinctIter struct { 1021 childIter sql.RowIter 1022 seen sql.KeyValueCache 1023 dispose sql.DisposeFunc 1024 } 1025 1026 func newDistinctIter(ctx *sql.Context, child sql.RowIter) *distinctIter { 1027 cache, dispose := ctx.Memory.NewHistoryCache() 1028 return &distinctIter{ 1029 childIter: child, 1030 seen: cache, 1031 dispose: dispose, 1032 } 1033 } 1034 1035 func (di *distinctIter) Next(ctx *sql.Context) (sql.Row, error) { 1036 for { 1037 row, err := di.childIter.Next(ctx) 1038 if err != nil { 1039 if err == io.EOF { 1040 di.Dispose() 1041 } 1042 return nil, err 1043 } 1044 1045 hash, err := sql.HashOf(row) 1046 if err != nil { 1047 return nil, err 1048 } 1049 1050 if _, err := di.seen.Get(hash); err == nil { 1051 continue 1052 } 1053 1054 if err := di.seen.Put(hash, struct{}{}); err != nil { 1055 return nil, err 1056 } 1057 1058 return row, nil 1059 } 1060 } 1061 1062 func (di *distinctIter) Close(ctx *sql.Context) error { 1063 di.Dispose() 1064 return di.childIter.Close(ctx) 1065 } 1066 1067 func (di *distinctIter) Dispose() { 1068 if di.dispose != nil { 1069 di.dispose() 1070 } 1071 } 1072 1073 type unionIter struct { 1074 cur sql.RowIter 1075 nextIter func(ctx *sql.Context) (sql.RowIter, error) 1076 } 1077 1078 func (ui *unionIter) Next(ctx *sql.Context) (sql.Row, error) { 1079 res, err := ui.cur.Next(ctx) 1080 if err == io.EOF { 1081 if ui.nextIter == nil { 1082 return nil, io.EOF 1083 } 1084 err = ui.cur.Close(ctx) 1085 if err != nil { 1086 return nil, err 1087 } 1088 ui.cur, err = ui.nextIter(ctx) 1089 ui.nextIter = nil 1090 if err != nil { 1091 return nil, err 1092 } 1093 return ui.cur.Next(ctx) 1094 } 1095 return res, err 1096 } 1097 1098 func (ui *unionIter) Close(ctx *sql.Context) error { 1099 if ui.cur != nil { 1100 return ui.cur.Close(ctx) 1101 } else { 1102 return nil 1103 } 1104 } 1105 1106 type intersectIter struct { 1107 lIter, rIter sql.RowIter 1108 cached bool 1109 cache map[uint64]int 1110 } 1111 1112 func (ii *intersectIter) Next(ctx *sql.Context) (sql.Row, error) { 1113 if !ii.cached { 1114 ii.cache = make(map[uint64]int) 1115 for { 1116 res, err := ii.rIter.Next(ctx) 1117 if err != nil && err != io.EOF { 1118 return nil, err 1119 } 1120 1121 hash, herr := sql.HashOf(res) 1122 if herr != nil { 1123 return nil, herr 1124 } 1125 if _, ok := ii.cache[hash]; !ok { 1126 ii.cache[hash] = 0 1127 } 1128 ii.cache[hash]++ 1129 1130 if err == io.EOF { 1131 break 1132 } 1133 } 1134 ii.cached = true 1135 } 1136 1137 for { 1138 res, err := ii.lIter.Next(ctx) 1139 if err != nil { 1140 return nil, err 1141 } 1142 1143 hash, herr := sql.HashOf(res) 1144 if herr != nil { 1145 return nil, herr 1146 } 1147 if _, ok := ii.cache[hash]; !ok { 1148 continue 1149 } 1150 if ii.cache[hash] <= 0 { 1151 continue 1152 } 1153 ii.cache[hash]-- 1154 1155 return res, nil 1156 } 1157 } 1158 1159 func (ii *intersectIter) Close(ctx *sql.Context) error { 1160 if ii.lIter != nil { 1161 if err := ii.lIter.Close(ctx); err != nil { 1162 return err 1163 } 1164 } 1165 if ii.rIter != nil { 1166 if err := ii.rIter.Close(ctx); err != nil { 1167 return err 1168 } 1169 } 1170 return nil 1171 } 1172 1173 type exceptIter struct { 1174 lIter, rIter sql.RowIter 1175 cached bool 1176 cache map[uint64]int 1177 } 1178 1179 func (ei *exceptIter) Next(ctx *sql.Context) (sql.Row, error) { 1180 if !ei.cached { 1181 ei.cache = make(map[uint64]int) 1182 for { 1183 res, err := ei.rIter.Next(ctx) 1184 if err != nil && err != io.EOF { 1185 return nil, err 1186 } 1187 1188 hash, herr := sql.HashOf(res) 1189 if herr != nil { 1190 return nil, herr 1191 } 1192 if _, ok := ei.cache[hash]; !ok { 1193 ei.cache[hash] = 0 1194 } 1195 ei.cache[hash]++ 1196 1197 if err == io.EOF { 1198 break 1199 } 1200 } 1201 ei.cached = true 1202 } 1203 1204 for { 1205 res, err := ei.lIter.Next(ctx) 1206 if err != nil { 1207 return nil, err 1208 } 1209 1210 hash, herr := sql.HashOf(res) 1211 if herr != nil { 1212 return nil, herr 1213 } 1214 if _, ok := ei.cache[hash]; !ok { 1215 return res, nil 1216 } 1217 if ei.cache[hash] <= 0 { 1218 return res, nil 1219 } 1220 ei.cache[hash]-- 1221 } 1222 } 1223 1224 func (ei *exceptIter) Close(ctx *sql.Context) error { 1225 if ei.lIter != nil { 1226 if err := ei.lIter.Close(ctx); err != nil { 1227 return err 1228 } 1229 } 1230 if ei.rIter != nil { 1231 if err := ei.rIter.Close(ctx); err != nil { 1232 return err 1233 } 1234 } 1235 return nil 1236 }