github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/merge_join.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rowexec 16 17 import ( 18 "errors" 19 "io" 20 21 "github.com/dolthub/go-mysql-server/sql/plan" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/expression" 25 ) 26 27 var ErrMergeJoinExpectsComparerFilters = errors.New("merge join expects expression.Comparer filters, found: %T") 28 29 // NewMergeJoin returns a node that performs a presorted merge join on 30 // two relations. We require 1) the join filter is an equality with disjoint 31 // join attributes, 2) the free attributes for a relation are a prefix for 32 // an index that will be used to return sorted rows. 
33 func NewMergeJoin(left, right sql.Node, cond sql.Expression) *plan.JoinNode { 34 return plan.NewJoin(left, right, plan.JoinTypeMerge, cond) 35 } 36 37 func NewLeftMergeJoin(left, right sql.Node, cond sql.Expression) *plan.JoinNode { 38 return plan.NewJoin(left, right, plan.JoinTypeLeftOuterMerge, cond) 39 } 40 41 func newMergeJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) { 42 l, err := b.Build(ctx, j.Left(), row) 43 if err != nil { 44 return nil, err 45 } 46 r, err := b.Build(ctx, j.Right(), row) 47 if err != nil { 48 return nil, err 49 } 50 51 fullRow := make(sql.Row, len(row)+len(j.Left().Schema())+len(j.Right().Schema())) 52 fullRow[0] = row 53 if len(row) > 0 { 54 copy(fullRow[0:], row[:]) 55 } 56 57 // a merge join's first filter provides direction information 58 // for which iter to update next 59 filters := expression.SplitConjunction(j.Filter) 60 cmp, ok := filters[0].(expression.Comparer) 61 if !ok { 62 return nil, sql.ErrMergeJoinExpectsComparerFilters.New(filters[0]) 63 } 64 65 if len(filters) == 0 { 66 return nil, sql.ErrNoJoinFilters.New() 67 } 68 69 var iter sql.RowIter = &mergeJoinIter{ 70 left: l, 71 right: r, 72 filters: filters[1:], 73 cmp: cmp, 74 typ: j.Op, 75 fullRow: fullRow, 76 scopeLen: j.ScopeLen, 77 parentLen: len(row) - j.ScopeLen, 78 leftRowLen: len(j.Left().Schema()), 79 rightRowLen: len(j.Right().Schema()), 80 } 81 return iter, nil 82 } 83 84 // mergeJoinIter alternates incrementing two RowIters, assuming 85 // rows will be provided in a sorted order given the join |expr| 86 // (see sortedIndexScanForTableCol). Extra join |filters| that do 87 // not provide a directional ordering signal for index iteration 88 // are evaluated separately. 
89 type mergeJoinIter struct { 90 // cmp is a directional indicator for row iter increments 91 cmp expression.Comparer 92 // filters is the remaining set of join conditions 93 filters []sql.Expression 94 left sql.RowIter 95 right sql.RowIter 96 fullRow sql.Row 97 98 // match lookahead buffers and state tracking (private to match) 99 rightBuf []sql.Row 100 bufI int 101 rightPeek sql.Row 102 leftPeek sql.Row 103 rightDone bool 104 leftDone bool 105 106 // matchIncLeft indicates whether the most recent |i.incMatch| 107 // call incremented the left row. 108 matchIncLeft bool 109 // leftMatched indicates whether the current left in |i.fullRow| 110 // has satisfied the join condition. 111 leftMatched bool 112 113 // lifecycle maintenance 114 init bool 115 leftExhausted bool 116 rightExhausted bool 117 118 typ plan.JoinType 119 scopeLen int 120 leftRowLen int 121 rightRowLen int 122 parentLen int 123 } 124 125 var _ sql.RowIter = (*mergeJoinIter)(nil) 126 127 func (i *mergeJoinIter) sel(ctx *sql.Context, row sql.Row) (bool, error) { 128 for _, f := range i.filters { 129 res, err := sql.EvaluateCondition(ctx, f, row) 130 if err != nil { 131 return false, err 132 } 133 134 if !sql.IsTrue(res) { 135 return false, nil 136 } 137 } 138 return true, nil 139 } 140 141 type mergeState uint8 142 143 const ( 144 msInit mergeState = iota 145 msExhaustCheck 146 msCompare 147 msIncLeft 148 msIncRight 149 msSelect 150 msRet 151 msRetLeft 152 msRejectNull 153 ) 154 155 func (i *mergeJoinIter) Next(ctx *sql.Context) (sql.Row, error) { 156 var err error 157 var ret sql.Row 158 var res int 159 160 // The common inner join match flow: 161 // 1) check for io.EOF 162 // 2) evaluate compare filter 163 // 3) evaluate select filters 164 // 4) initialize match state 165 // 5) drain match state 166 // 6) repeat 167 // 168 // Left-join matching is unique. 
At any given time, we need to know whether 169 // a unique left row: 1) has already matched, 2) has more right rows 170 // available for matching before we can return a nullified-row. Otherwise 171 // we may accidentally return nullified rows that have matches (before or 172 // after the current row), or fail to return a nullified row that has no 173 // matches. 174 // 175 // We use two variables to manage the lookahead state management. 176 // |matchedleft| is a forward-looking indicator of whether the current left 177 // row has satisfied a join condition. It is reset to false when we 178 // increment left. |matchincleft| is true when the most recent call to 179 // |incmatch| incremented the left row. The two vars combined let us 180 // lookahead during msSelect to 1) identify proper nullified row matches, 181 // and 2) maintain forward-looking state for the next |i.fullrow|. 182 // 183 nextState := msInit 184 for { 185 switch nextState { 186 case msInit: 187 if !i.init { 188 err = i.initIters(ctx) 189 if err != nil { 190 return nil, err 191 } 192 } 193 nextState = msExhaustCheck 194 case msExhaustCheck: 195 if i.lojFinalize() { 196 ret = i.copyReturnRow() 197 nextState = msRetLeft 198 } else if i.exhausted() { 199 return nil, io.EOF 200 } else { 201 nextState = msCompare 202 } 203 case msCompare: 204 res, err = i.cmp.Compare(ctx, i.fullRow) 205 if expression.ErrNilOperand.Is(err) { 206 nextState = msRejectNull 207 break 208 } else if err != nil { 209 return nil, err 210 } 211 switch { 212 case res < 0: 213 if i.typ.IsLeftOuter() { 214 if i.leftMatched { 215 nextState = msIncLeft 216 } else { 217 ret = i.copyReturnRow() 218 nextState = msRetLeft 219 } 220 } else { 221 nextState = msIncLeft 222 } 223 case res > 0: 224 nextState = msIncRight 225 case res == 0: 226 nextState = msSelect 227 } 228 case msRejectNull: 229 left, _ := i.cmp.Left().Eval(ctx, i.fullRow) 230 if left == nil { 231 if i.typ.IsLeftOuter() && !i.leftMatched { 232 ret = i.copyReturnRow() 233 
nextState = msRetLeft 234 } else { 235 nextState = msIncLeft 236 } 237 } else { 238 nextState = msIncRight 239 } 240 case msIncLeft: 241 err = i.incLeft(ctx) 242 nextState = msExhaustCheck 243 case msIncRight: 244 err = i.incRight(ctx) 245 nextState = msExhaustCheck 246 case msSelect: 247 ret = i.copyReturnRow() 248 currLeftMatched := i.leftMatched 249 250 ok, err := i.sel(ctx, ret) 251 if err != nil { 252 return nil, err 253 } 254 err = i.incMatch(ctx) 255 if err != nil { 256 return nil, err 257 } 258 if ok { 259 if !i.matchIncLeft { 260 // |leftMatched| is forward-looking, sets state for 261 // current |i.fullRow| (next |ret|) 262 i.leftMatched = true 263 } 264 265 nextState = msRet 266 break 267 } 268 269 if !i.typ.IsLeftOuter() { 270 nextState = msExhaustCheck 271 break 272 } 273 274 if i.matchIncLeft && !currLeftMatched { 275 // |i.matchIncLeft| indicates whether the most recent 276 // |i.incMatch| call incremented the left row. 277 // |currLeftMatched| indicates whether |ret| has already 278 // successfully met a join condition. 279 return i.removeParentRow(i.nullifyRightRow(ret)), nil 280 } else { 281 nextState = msExhaustCheck 282 } 283 284 case msRet: 285 return i.removeParentRow(ret), nil 286 case msRetLeft: 287 ret = i.removeParentRow(i.nullifyRightRow(ret)) 288 err = i.incLeft(ctx) 289 if err != nil { 290 return nil, err 291 } 292 return ret, nil 293 } 294 } 295 } 296 297 func (i *mergeJoinIter) copyReturnRow() sql.Row { 298 ret := make(sql.Row, len(i.fullRow)) 299 copy(ret, i.fullRow) 300 return ret 301 } 302 303 // incMatch uses two phases to find all left and right rows that match their 304 // companion rows for the given match stats: 305 // 1. collect all right rows that match the current left row into a buffer; 306 // 2. for every left row that matches the original right row, match every 307 // right row. 308 // 309 // We maintain lookaheads for the first non-matching row in each iter. 
If 310 // there is no next non-matching row (io.EOF), we trigger |i.exhausted| at 311 // the appropriate time depending on whether we are left-joining. 312 func (i *mergeJoinIter) incMatch(ctx *sql.Context) error { 313 i.matchIncLeft = false 314 315 if !i.rightDone { 316 // initialize right matches buffer 317 right := make(sql.Row, i.rightRowLen) 318 copy(right, i.fullRow[i.scopeLen+i.parentLen+i.leftRowLen:]) 319 i.rightBuf = append(i.rightBuf, right) 320 321 match := true 322 var err error 323 var peek sql.Row 324 for match { 325 match, peek, err = i.peekMatch(ctx, i.right) 326 if err != nil { 327 return err 328 } else if match { 329 i.rightBuf = append(i.rightBuf, peek) 330 } else { 331 i.rightPeek = peek 332 i.rightDone = true 333 } 334 } 335 // left row 1 and right row 1 is a duplicate of the first match 336 // captured in outer closure, slough one iteration 337 err = i.incMatch(ctx) 338 if err != nil { 339 return err 340 } 341 342 } 343 344 if i.bufI > len(i.rightBuf)-1 { 345 // matched entire right buffer to the current left row, reset 346 i.matchIncLeft = true 347 i.bufI = 0 348 match, peek, err := i.peekMatch(ctx, i.left) 349 if err != nil { 350 return err 351 } else if !match { 352 i.leftPeek = peek 353 i.leftDone = true 354 } 355 i.leftMatched = false 356 } 357 358 if !i.leftDone { 359 // rightBuf has already been validated, we don't need compare 360 copySubslice(i.fullRow, i.rightBuf[i.bufI], i.scopeLen+i.parentLen+i.leftRowLen) 361 i.bufI++ 362 return nil 363 } 364 365 defer i.resetMatchState() 366 367 if i.leftPeek == nil { 368 i.leftExhausted = true 369 } 370 if i.rightPeek == nil { 371 i.rightExhausted = true 372 } 373 374 if i.exhausted() { 375 if i.lojFinalize() { 376 // left joins expect the left row in |i.fullRow| as long 377 // as the left iter is not exhausted. 378 copySubslice(i.fullRow, i.leftPeek, i.scopeLen+i.parentLen) 379 } 380 return nil 381 } 382 383 // both lookaheads fail the join condition. 
Drain 384 // lookahead rows / increment both iterators. 385 i.matchIncLeft = true 386 copySubslice(i.fullRow, i.leftPeek, i.scopeLen+i.parentLen) 387 copySubslice(i.fullRow, i.rightPeek, i.scopeLen+i.parentLen+i.leftRowLen) 388 389 return nil 390 } 391 392 // lojFinalize is a unique state where we have exhausted the outer iterator, 393 // but not the inner iterator we are outer joining against. 394 func (i *mergeJoinIter) lojFinalize() bool { 395 return i.rightExhausted && !i.leftExhausted && i.typ.IsLeftOuter() 396 } 397 398 // nullifyRightRow sets the values corresponding to the right row to nil 399 func (i *mergeJoinIter) nullifyRightRow(r sql.Row) sql.Row { 400 for j := i.scopeLen + i.parentLen + i.leftRowLen; j < len(r); j++ { 401 r[j] = nil 402 } 403 return r 404 } 405 406 // initIters populates i.fullRow and clears the match state 407 func (i *mergeJoinIter) initIters(ctx *sql.Context) error { 408 err := i.incLeft(ctx) 409 if err != nil { 410 return err 411 } 412 err = i.incRight(ctx) 413 if err != nil { 414 return err 415 } 416 i.init = true 417 i.resetMatchState() 418 return nil 419 } 420 421 // resetMatchState clears the match state variables to zero values 422 func (i *mergeJoinIter) resetMatchState() { 423 i.leftPeek = nil 424 i.rightPeek = nil 425 i.leftDone = false 426 i.rightDone = false 427 i.rightBuf = i.rightBuf[:0] 428 i.bufI = 0 429 } 430 431 // peekMatch reads the next row from an iterator, attempts to update i.fullRow 432 // to find a matching condition, rewinding the change in the case of failure. 433 // We return whether a successful match was found, the lookahead row for saving 434 // in the case of failure, and an error or nil. If the iterator io.EOFs, we return 435 // no match, no lookahead row, and no error. 
func (i *mergeJoinIter) peekMatch(ctx *sql.Context, iter sql.RowIter) (bool, sql.Row, error) {
	// |off| is where |iter|'s columns start in |i.fullRow|; |restore| saves
	// those columns so a failed lookahead can be rolled back.
	var off int
	var restore sql.Row
	switch iter {
	case i.left:
		off = i.scopeLen + i.parentLen
		restore = make(sql.Row, i.leftRowLen)
		copy(restore, i.fullRow[off:off+i.leftRowLen])
	case i.right:
		off = i.scopeLen + i.parentLen + i.leftRowLen
		restore = make(sql.Row, i.rightRowLen)
		copy(restore, i.fullRow[off:off+i.rightRowLen])
	default:
		// Without a known offset the lookahead would overwrite the
		// scope/parent prefix of |i.fullRow|.
		return false, nil, errors.New("merge join: peekMatch called with an iterator that is neither left nor right")
	}

	// peek lookahead
	peek, err := iter.Next(ctx)
	if errors.Is(err, io.EOF) {
		// io.EOF is the only nil row nil err return
		return false, nil, nil
	} else if err != nil {
		return false, nil, err
	}

	// check if lookahead valid
	copySubslice(i.fullRow, peek, off)
	res, err := i.cmp.Compare(ctx, i.fullRow)
	if expression.ErrNilOperand.Is(err) {
		// A NULL operand never satisfies the join comparison. |res| carries
		// no meaning alongside an error, so report a non-match explicitly
		// rather than falling through to the |res| check below.
		copySubslice(i.fullRow, restore, off)
		return false, peek, nil
	} else if err != nil {
		return false, nil, err
	}
	if res != 0 {
		// revert change to output row if no match
		copySubslice(i.fullRow, restore, off)
		return false, peek, nil
	}
	return true, peek, nil
}

// exhausted returns true if either iterator has io.EOF'd
func (i *mergeJoinIter) exhausted() bool {
	return i.leftExhausted || i.rightExhausted
}

// copySubslice copies |src| into |dst| starting at index |off|. It panics if
// |src| does not fit in |dst[off:]|.
func copySubslice(dst, src sql.Row, off int) {
	for i, v := range src {
		dst[off+i] = v
	}
}

// incLeft updates |i.fullRow|'s left row, consuming |i.leftPeek| if one is
// buffered. On io.EOF it sets |i.leftExhausted| and returns nil, leaving the
// previous left columns in place. It always clears |i.leftMatched|.
func (i *mergeJoinIter) incLeft(ctx *sql.Context) error {
	i.leftMatched = false
	var row sql.Row
	var err error
	if i.leftPeek != nil {
		row = i.leftPeek
		i.leftPeek = nil
	} else {
		row, err = i.left.Next(ctx)
		if errors.Is(err, io.EOF) {
			i.leftExhausted = true
			return nil
		} else if err != nil {
			return err
		}
	}

	copySubslice(i.fullRow, row, i.scopeLen+i.parentLen)
	return nil
}

// incRight updates |i.fullRow|'s right row, consuming |i.rightPeek| if one is
// buffered. On io.EOF it sets |i.rightExhausted| and returns nil, leaving the
// previous right columns in place.
func (i *mergeJoinIter) incRight(ctx *sql.Context) error {
	var row sql.Row
	var err error
	if i.rightPeek != nil {
		row = i.rightPeek
		i.rightPeek = nil
	} else {
		row, err = i.right.Next(ctx)
		if errors.Is(err, io.EOF) {
			i.rightExhausted = true
			return nil
		} else if err != nil {
			return err
		}
	}

	copySubslice(i.fullRow, row, i.scopeLen+i.parentLen+i.leftRowLen)
	return nil
}

// incIter reads the next row from |iter| into |i.fullRow| starting at |off|.
// Unlike incLeft/incRight it ignores any buffered lookahead and returns
// io.EOF (or any other error) unchanged.
func (i *mergeJoinIter) incIter(ctx *sql.Context, iter sql.RowIter, off int) error {
	row, err := iter.Next(ctx)
	if err != nil {
		return err
	}
	copySubslice(i.fullRow, row, off)
	return nil
}

// removeParentRow drops the parent columns, positions
// [scopeLen, scopeLen+parentLen), from |r| by shifting the later columns left
// in place, and returns the shortened slice. |r|'s backing array is modified.
func (i *mergeJoinIter) removeParentRow(r sql.Row) sql.Row {
	copy(r[i.scopeLen:], r[i.scopeLen+i.parentLen:])
	r = r[:len(r)-i.parentLen]
	return r
}

// Close closes both child iterators. Both are always closed; if both fail,
// the left iterator's error is the one returned.
func (i *mergeJoinIter) Close(ctx *sql.Context) (err error) {
	if i.left != nil {
		err = i.left.Close(ctx)
	}

	if i.right != nil {
		if err == nil {
			err = i.right.Close(ctx)
		} else {
			// best effort: the left error is already being reported
			i.right.Close(ctx)
		}
	}

	return err
}