github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/merge_join.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"errors"
    19  	"io"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql/plan"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  )
    26  
    27  var ErrMergeJoinExpectsComparerFilters = errors.New("merge join expects expression.Comparer filters, found: %T")
    28  
    29  // NewMergeJoin returns a node that performs a presorted merge join on
    30  // two relations. We require 1) the join filter is an equality with disjoint
    31  // join attributes, 2) the free attributes for a relation are a prefix for
    32  // an index that will be used to return sorted rows.
    33  func NewMergeJoin(left, right sql.Node, cond sql.Expression) *plan.JoinNode {
    34  	return plan.NewJoin(left, right, plan.JoinTypeMerge, cond)
    35  }
    36  
    37  func NewLeftMergeJoin(left, right sql.Node, cond sql.Expression) *plan.JoinNode {
    38  	return plan.NewJoin(left, right, plan.JoinTypeLeftOuterMerge, cond)
    39  }
    40  
    41  func newMergeJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) {
    42  	l, err := b.Build(ctx, j.Left(), row)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  	r, err := b.Build(ctx, j.Right(), row)
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  
    51  	fullRow := make(sql.Row, len(row)+len(j.Left().Schema())+len(j.Right().Schema()))
    52  	fullRow[0] = row
    53  	if len(row) > 0 {
    54  		copy(fullRow[0:], row[:])
    55  	}
    56  
    57  	// a merge join's first filter provides direction information
    58  	// for which iter to update next
    59  	filters := expression.SplitConjunction(j.Filter)
    60  	cmp, ok := filters[0].(expression.Comparer)
    61  	if !ok {
    62  		return nil, sql.ErrMergeJoinExpectsComparerFilters.New(filters[0])
    63  	}
    64  
    65  	if len(filters) == 0 {
    66  		return nil, sql.ErrNoJoinFilters.New()
    67  	}
    68  
    69  	var iter sql.RowIter = &mergeJoinIter{
    70  		left:        l,
    71  		right:       r,
    72  		filters:     filters[1:],
    73  		cmp:         cmp,
    74  		typ:         j.Op,
    75  		fullRow:     fullRow,
    76  		scopeLen:    j.ScopeLen,
    77  		parentLen:   len(row) - j.ScopeLen,
    78  		leftRowLen:  len(j.Left().Schema()),
    79  		rightRowLen: len(j.Right().Schema()),
    80  	}
    81  	return iter, nil
    82  }
    83  
// mergeJoinIter alternates incrementing two RowIters, assuming
// rows will be provided in a sorted order given the join |expr|
// (see sortedIndexScanForTableCol). Extra join |filters| that do
// not provide a directional ordering signal for index iteration
// are evaluated separately.
//
// The working row |fullRow| is laid out as
// [scope values | parent values | left row | right row], with segment widths
// |scopeLen|, |parentLen|, |leftRowLen| and |rightRowLen| respectively.
type mergeJoinIter struct {
	// cmp is a directional indicator for row iter increments
	cmp expression.Comparer
	// filters is the remaining set of join conditions, evaluated by |sel|
	// once |cmp| reports equality
	filters []sql.Expression
	// left and right are the child iterators being merged
	left    sql.RowIter
	right   sql.RowIter
	// fullRow is the working row that |cmp| and |filters| are evaluated on
	fullRow sql.Row

	// match lookahead buffers and state tracking (private to match)
	// rightBuf holds the right rows that compare equal to the current left
	// row; bufI is the index of the next buffered row to place in |fullRow|
	rightBuf  []sql.Row
	bufI      int
	// rightPeek and leftPeek hold the first row read from each iterator that
	// did not match; nil when the iterator hit io.EOF during the lookahead
	rightPeek sql.Row
	leftPeek  sql.Row
	// rightDone and leftDone are set once the respective lookahead has found
	// a non-matching row or io.EOF
	rightDone bool
	leftDone  bool

	// matchIncLeft indicates whether the most recent |i.incMatch|
	// call incremented the left row.
	matchIncLeft bool
	// leftMatched indicates whether the current left in |i.fullRow|
	// has satisfied the join condition.
	leftMatched bool

	// lifecycle maintenance
	// init is set by |initIters| after the first left and right rows are read
	init           bool
	// leftExhausted and rightExhausted record that the respective iterator
	// has no more rows to offer
	leftExhausted  bool
	rightExhausted bool

	// typ is the join operator; it is consulted for left outer semantics
	typ         plan.JoinType
	// scopeLen, parentLen, leftRowLen and rightRowLen are the widths of the
	// four |fullRow| segments
	scopeLen    int
	leftRowLen  int
	rightRowLen int
	parentLen   int
}

// compile-time check that mergeJoinIter implements sql.RowIter
var _ sql.RowIter = (*mergeJoinIter)(nil)
   126  
   127  func (i *mergeJoinIter) sel(ctx *sql.Context, row sql.Row) (bool, error) {
   128  	for _, f := range i.filters {
   129  		res, err := sql.EvaluateCondition(ctx, f, row)
   130  		if err != nil {
   131  			return false, err
   132  		}
   133  
   134  		if !sql.IsTrue(res) {
   135  			return false, nil
   136  		}
   137  	}
   138  	return true, nil
   139  }
   140  
// mergeState enumerates the steps of the state machine driven by
// mergeJoinIter.Next.
type mergeState uint8

const (
	// msInit reads the first left and right rows on the first call to Next.
	msInit mergeState = iota
	// msExhaustCheck chooses between returning io.EOF, emitting a remaining
	// left row for a left outer join, and comparing the current rows.
	msExhaustCheck
	// msCompare evaluates the directional comparison on the working row.
	msCompare
	// msIncLeft advances the left iterator.
	msIncLeft
	// msIncRight advances the right iterator.
	msIncRight
	// msSelect evaluates the remaining filters on rows that compared equal
	// and advances the match state.
	msSelect
	// msRet returns the current matched row.
	msRet
	// msRetLeft returns the current left row with the right side nullified
	// and advances the left iterator.
	msRetLeft
	// msRejectNull handles a comparison that failed because of a nil operand.
	msRejectNull
)
   154  
   155  func (i *mergeJoinIter) Next(ctx *sql.Context) (sql.Row, error) {
   156  	var err error
   157  	var ret sql.Row
   158  	var res int
   159  
   160  	//  The common inner join match flow:
   161  	//  1) check for io.EOF
   162  	//  2) evaluate compare filter
   163  	//  3) evaluate select filters
   164  	//  4) initialize match state
   165  	//  5) drain match state
   166  	//  6) repeat
   167  	//
   168  	// Left-join matching is unique. At any given time, we need to know whether
   169  	// a unique left row: 1) has already matched, 2) has more right rows
   170  	// available for matching before we can return a nullified-row. Otherwise
   171  	// we may accidentally return nullified rows that have matches (before or
   172  	// after the current row), or fail to return a nullified row that has no
   173  	// matches.
   174  	//
   175  	// We use two variables to manage the lookahead state management.
   176  	// |matchedleft| is a forward-looking indicator of whether the current left
   177  	// row has satisfied a join condition. It is reset to false when we
   178  	// increment left. |matchincleft| is true when the most recent call to
   179  	// |incmatch| incremented the left row. The two vars combined let us
   180  	// lookahead during msSelect to 1) identify proper nullified row matches,
   181  	// and 2) maintain forward-looking state for the next |i.fullrow|.
   182  	//
   183  	nextState := msInit
   184  	for {
   185  		switch nextState {
   186  		case msInit:
   187  			if !i.init {
   188  				err = i.initIters(ctx)
   189  				if err != nil {
   190  					return nil, err
   191  				}
   192  			}
   193  			nextState = msExhaustCheck
   194  		case msExhaustCheck:
   195  			if i.lojFinalize() {
   196  				ret = i.copyReturnRow()
   197  				nextState = msRetLeft
   198  			} else if i.exhausted() {
   199  				return nil, io.EOF
   200  			} else {
   201  				nextState = msCompare
   202  			}
   203  		case msCompare:
   204  			res, err = i.cmp.Compare(ctx, i.fullRow)
   205  			if expression.ErrNilOperand.Is(err) {
   206  				nextState = msRejectNull
   207  				break
   208  			} else if err != nil {
   209  				return nil, err
   210  			}
   211  			switch {
   212  			case res < 0:
   213  				if i.typ.IsLeftOuter() {
   214  					if i.leftMatched {
   215  						nextState = msIncLeft
   216  					} else {
   217  						ret = i.copyReturnRow()
   218  						nextState = msRetLeft
   219  					}
   220  				} else {
   221  					nextState = msIncLeft
   222  				}
   223  			case res > 0:
   224  				nextState = msIncRight
   225  			case res == 0:
   226  				nextState = msSelect
   227  			}
   228  		case msRejectNull:
   229  			left, _ := i.cmp.Left().Eval(ctx, i.fullRow)
   230  			if left == nil {
   231  				if i.typ.IsLeftOuter() && !i.leftMatched {
   232  					ret = i.copyReturnRow()
   233  					nextState = msRetLeft
   234  				} else {
   235  					nextState = msIncLeft
   236  				}
   237  			} else {
   238  				nextState = msIncRight
   239  			}
   240  		case msIncLeft:
   241  			err = i.incLeft(ctx)
   242  			nextState = msExhaustCheck
   243  		case msIncRight:
   244  			err = i.incRight(ctx)
   245  			nextState = msExhaustCheck
   246  		case msSelect:
   247  			ret = i.copyReturnRow()
   248  			currLeftMatched := i.leftMatched
   249  
   250  			ok, err := i.sel(ctx, ret)
   251  			if err != nil {
   252  				return nil, err
   253  			}
   254  			err = i.incMatch(ctx)
   255  			if err != nil {
   256  				return nil, err
   257  			}
   258  			if ok {
   259  				if !i.matchIncLeft {
   260  					// |leftMatched| is forward-looking, sets state for
   261  					// current |i.fullRow| (next |ret|)
   262  					i.leftMatched = true
   263  				}
   264  
   265  				nextState = msRet
   266  				break
   267  			}
   268  
   269  			if !i.typ.IsLeftOuter() {
   270  				nextState = msExhaustCheck
   271  				break
   272  			}
   273  
   274  			if i.matchIncLeft && !currLeftMatched {
   275  				// |i.matchIncLeft| indicates whether the most recent
   276  				// |i.incMatch| call incremented the left row.
   277  				// |currLeftMatched| indicates whether |ret| has already
   278  				// successfully met a join condition.
   279  				return i.removeParentRow(i.nullifyRightRow(ret)), nil
   280  			} else {
   281  				nextState = msExhaustCheck
   282  			}
   283  
   284  		case msRet:
   285  			return i.removeParentRow(ret), nil
   286  		case msRetLeft:
   287  			ret = i.removeParentRow(i.nullifyRightRow(ret))
   288  			err = i.incLeft(ctx)
   289  			if err != nil {
   290  				return nil, err
   291  			}
   292  			return ret, nil
   293  		}
   294  	}
   295  }
   296  
   297  func (i *mergeJoinIter) copyReturnRow() sql.Row {
   298  	ret := make(sql.Row, len(i.fullRow))
   299  	copy(ret, i.fullRow)
   300  	return ret
   301  }
   302  
   303  // incMatch uses two phases to find all left and right rows that match their
   304  // companion rows for the given match stats:
   305  //  1. collect all right rows that match the current left row into a buffer;
   306  //  2. for every left row that matches the original right row, match every
   307  //     right row.
   308  //
   309  // We maintain lookaheads for the first non-matching row in each iter. If
   310  // there is no next non-matching row (io.EOF), we trigger |i.exhausted| at
   311  // the appropriate time depending on whether we are left-joining.
   312  func (i *mergeJoinIter) incMatch(ctx *sql.Context) error {
   313  	i.matchIncLeft = false
   314  
   315  	if !i.rightDone {
   316  		// initialize right matches buffer
   317  		right := make(sql.Row, i.rightRowLen)
   318  		copy(right, i.fullRow[i.scopeLen+i.parentLen+i.leftRowLen:])
   319  		i.rightBuf = append(i.rightBuf, right)
   320  
   321  		match := true
   322  		var err error
   323  		var peek sql.Row
   324  		for match {
   325  			match, peek, err = i.peekMatch(ctx, i.right)
   326  			if err != nil {
   327  				return err
   328  			} else if match {
   329  				i.rightBuf = append(i.rightBuf, peek)
   330  			} else {
   331  				i.rightPeek = peek
   332  				i.rightDone = true
   333  			}
   334  		}
   335  		// left row 1 and right row 1 is a duplicate of the first match
   336  		// captured in outer closure, slough one iteration
   337  		err = i.incMatch(ctx)
   338  		if err != nil {
   339  			return err
   340  		}
   341  
   342  	}
   343  
   344  	if i.bufI > len(i.rightBuf)-1 {
   345  		// matched entire right buffer to the current left row, reset
   346  		i.matchIncLeft = true
   347  		i.bufI = 0
   348  		match, peek, err := i.peekMatch(ctx, i.left)
   349  		if err != nil {
   350  			return err
   351  		} else if !match {
   352  			i.leftPeek = peek
   353  			i.leftDone = true
   354  		}
   355  		i.leftMatched = false
   356  	}
   357  
   358  	if !i.leftDone {
   359  		// rightBuf has already been validated, we don't need compare
   360  		copySubslice(i.fullRow, i.rightBuf[i.bufI], i.scopeLen+i.parentLen+i.leftRowLen)
   361  		i.bufI++
   362  		return nil
   363  	}
   364  
   365  	defer i.resetMatchState()
   366  
   367  	if i.leftPeek == nil {
   368  		i.leftExhausted = true
   369  	}
   370  	if i.rightPeek == nil {
   371  		i.rightExhausted = true
   372  	}
   373  
   374  	if i.exhausted() {
   375  		if i.lojFinalize() {
   376  			// left joins expect the left row in |i.fullRow| as long
   377  			// as the left iter is not exhausted.
   378  			copySubslice(i.fullRow, i.leftPeek, i.scopeLen+i.parentLen)
   379  		}
   380  		return nil
   381  	}
   382  
   383  	// both lookaheads fail the join condition. Drain
   384  	// lookahead rows / increment both iterators.
   385  	i.matchIncLeft = true
   386  	copySubslice(i.fullRow, i.leftPeek, i.scopeLen+i.parentLen)
   387  	copySubslice(i.fullRow, i.rightPeek, i.scopeLen+i.parentLen+i.leftRowLen)
   388  
   389  	return nil
   390  }
   391  
   392  // lojFinalize is a unique state where we have exhausted the outer iterator,
   393  // but not the inner iterator we are outer joining against.
   394  func (i *mergeJoinIter) lojFinalize() bool {
   395  	return i.rightExhausted && !i.leftExhausted && i.typ.IsLeftOuter()
   396  }
   397  
   398  // nullifyRightRow sets the values corresponding to the right row to nil
   399  func (i *mergeJoinIter) nullifyRightRow(r sql.Row) sql.Row {
   400  	for j := i.scopeLen + i.parentLen + i.leftRowLen; j < len(r); j++ {
   401  		r[j] = nil
   402  	}
   403  	return r
   404  }
   405  
   406  // initIters populates i.fullRow and clears the match state
   407  func (i *mergeJoinIter) initIters(ctx *sql.Context) error {
   408  	err := i.incLeft(ctx)
   409  	if err != nil {
   410  		return err
   411  	}
   412  	err = i.incRight(ctx)
   413  	if err != nil {
   414  		return err
   415  	}
   416  	i.init = true
   417  	i.resetMatchState()
   418  	return nil
   419  }
   420  
   421  // resetMatchState clears the match state variables to zero values
   422  func (i *mergeJoinIter) resetMatchState() {
   423  	i.leftPeek = nil
   424  	i.rightPeek = nil
   425  	i.leftDone = false
   426  	i.rightDone = false
   427  	i.rightBuf = i.rightBuf[:0]
   428  	i.bufI = 0
   429  }
   430  
   431  // peekMatch reads the next row from an iterator, attempts to update i.fullRow
   432  // to find a matching condition, rewinding the change in the case of failure.
   433  // We return whether a successful match was found, the lookahead row for saving
   434  // in the case of failure, and an error or nil. If the iterator io.EOFs, we return
   435  // no match, no lookahead row, and no error.
   436  func (i *mergeJoinIter) peekMatch(ctx *sql.Context, iter sql.RowIter) (bool, sql.Row, error) {
   437  	var off int
   438  	var restore sql.Row
   439  	switch iter {
   440  	case i.left:
   441  		off = i.scopeLen + i.parentLen
   442  		restore = make(sql.Row, i.leftRowLen)
   443  		copy(restore, i.fullRow[off:off+i.leftRowLen])
   444  	case i.right:
   445  		off = i.scopeLen + i.parentLen + i.leftRowLen
   446  		restore = make(sql.Row, i.rightRowLen)
   447  		copy(restore, i.fullRow[off:off+i.rightRowLen])
   448  	default:
   449  	}
   450  
   451  	// peek lookahead
   452  	peek, err := iter.Next(ctx)
   453  	if errors.Is(err, io.EOF) {
   454  		// io.EOF is the only nil row nil err return
   455  		return false, nil, nil
   456  	} else if err != nil {
   457  		return false, nil, err
   458  	}
   459  
   460  	// check if lookahead valid
   461  	copySubslice(i.fullRow, peek, off)
   462  	res, err := i.cmp.Compare(ctx, i.fullRow)
   463  	if expression.ErrNilOperand.Is(err) {
   464  		// revert change to output row if no match
   465  		copySubslice(i.fullRow, restore, off)
   466  	} else if err != nil {
   467  		return false, nil, err
   468  	}
   469  	if res != 0 {
   470  		// revert change to output row if no match
   471  		copySubslice(i.fullRow, restore, off)
   472  	}
   473  	return res == 0, peek, nil
   474  }
   475  
   476  // exhausted returns true if either iterator has io.EOF'd
   477  func (i *mergeJoinIter) exhausted() bool {
   478  	return i.leftExhausted || i.rightExhausted
   479  }
   480  
   481  // copySubslice copies |src| into |dst| starting at index |off|
   482  func copySubslice(dst, src sql.Row, off int) {
   483  	for i, v := range src {
   484  		dst[off+i] = v
   485  	}
   486  }
   487  
   488  // incLeft updates |i.fullRow|'s left row
   489  func (i *mergeJoinIter) incLeft(ctx *sql.Context) error {
   490  	i.leftMatched = false
   491  	var row sql.Row
   492  	var err error
   493  	if i.leftPeek != nil {
   494  		row = i.leftPeek
   495  		i.leftPeek = nil
   496  	} else {
   497  		row, err = i.left.Next(ctx)
   498  		if errors.Is(err, io.EOF) {
   499  			i.leftExhausted = true
   500  			return nil
   501  		} else if err != nil {
   502  			return err
   503  		}
   504  	}
   505  
   506  	off := i.scopeLen + i.parentLen
   507  	for j, v := range row {
   508  		i.fullRow[off+j] = v
   509  	}
   510  
   511  	return nil
   512  }
   513  
   514  // incRight updates |i.fullRow|'s right row
   515  func (i *mergeJoinIter) incRight(ctx *sql.Context) error {
   516  	var row sql.Row
   517  	var err error
   518  	if i.rightPeek != nil {
   519  		row = i.rightPeek
   520  		i.rightPeek = nil
   521  	} else {
   522  		row, err = i.right.Next(ctx)
   523  		if errors.Is(err, io.EOF) {
   524  			i.rightExhausted = true
   525  			return nil
   526  		} else if err != nil {
   527  			return err
   528  		}
   529  	}
   530  
   531  	off := i.scopeLen + i.parentLen + i.leftRowLen
   532  	for j, v := range row {
   533  		i.fullRow[off+j] = v
   534  	}
   535  
   536  	return nil
   537  }
   538  
   539  // incLeft updates |i.fullRow|'s |inRow|
   540  func (i *mergeJoinIter) incIter(ctx *sql.Context, iter sql.RowIter, off int) error {
   541  	row, err := iter.Next(ctx)
   542  	if err != nil {
   543  		return err
   544  	}
   545  	for j, v := range row {
   546  		i.fullRow[off+j] = v
   547  	}
   548  	return nil
   549  }
   550  
   551  func (i *mergeJoinIter) removeParentRow(r sql.Row) sql.Row {
   552  	copy(r[i.scopeLen:], r[i.scopeLen+i.parentLen:])
   553  	r = r[:len(r)-i.parentLen]
   554  	return r
   555  }
   556  
   557  func (i *mergeJoinIter) Close(ctx *sql.Context) (err error) {
   558  	if i.left != nil {
   559  		err = i.left.Close(ctx)
   560  	}
   561  
   562  	if i.right != nil {
   563  		if err == nil {
   564  			err = i.right.Close(ctx)
   565  		} else {
   566  			i.right.Close(ctx)
   567  		}
   568  	}
   569  
   570  	return err
   571  }