github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/range_heap_iter.go (about)

     1  package rowexec
     2  
     3  import (
     4  	"container/heap"
     5  	"errors"
     6  	"io"
     7  	"reflect"
     8  
     9  	"go.opentelemetry.io/otel/attribute"
    10  	"go.opentelemetry.io/otel/trace"
    11  
    12  	"github.com/dolthub/go-mysql-server/sql"
    13  	"github.com/dolthub/go-mysql-server/sql/plan"
    14  )
    15  
    16  func newRangeHeapJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) {
    17  	var leftName, rightName string
    18  	if leftTable, ok := j.Left().(sql.Nameable); ok {
    19  		leftName = leftTable.Name()
    20  	} else {
    21  		leftName = reflect.TypeOf(j.Left()).String()
    22  	}
    23  
    24  	if rightTable, ok := j.Right().(sql.Nameable); ok {
    25  		rightName = rightTable.Name()
    26  	} else {
    27  		rightName = reflect.TypeOf(j.Right()).String()
    28  	}
    29  
    30  	span, ctx := ctx.Span("plan.rangeHeapJoinIter", trace.WithAttributes(
    31  		attribute.String("left", leftName),
    32  		attribute.String("right", rightName),
    33  	))
    34  
    35  	l, err := b.Build(ctx, j.Left(), row)
    36  	if err != nil {
    37  		span.End()
    38  		return nil, err
    39  	}
    40  
    41  	rhp, ok := j.Right().(*plan.RangeHeap)
    42  	if !ok {
    43  		return nil, errors.New("right side of join must be a range heap")
    44  	}
    45  
    46  	return sql.NewSpanIter(span, &rangeHeapJoinIter{
    47  		parentRow:     row,
    48  		primary:       l,
    49  		cond:          j.Filter,
    50  		joinType:      j.Op,
    51  		rowSize:       len(row) + len(j.Left().Schema()) + len(j.Right().Schema()),
    52  		scopeLen:      j.ScopeLen,
    53  		b:             b,
    54  		rangeHeapPlan: rhp,
    55  	}), nil
    56  }
    57  
// rangeHeapJoinIter is an iterator that reads every row from the primary (left)
// iterator and, for each one, yields the rows of the range heap's child that are
// currently held in the heap of active ranges and satisfy the join condition.
// It also implements the methods required by container/heap (Len, Less, Swap,
// Push, Pop) over activeRanges.
type rangeHeapJoinIter struct {
	// parentRow is the row handed to the constructor; it is prepended to every
	// primary row and stripped again by removeParentRow.
	parentRow  sql.Row
	// primary iterates over the left child of the join.
	primary    sql.RowIter
	// primaryRow is parentRow plus the current primary row; nil means the next
	// primary row still has to be loaded.
	primaryRow sql.Row
	// secondary iterates over the active ranges for the current primary row;
	// nil means it has to be (re)created by loadSecondary.
	secondary  sql.RowIter
	// cond is the join filter evaluated against each combined row.
	cond       sql.Expression
	// joinType is the join operator taken from the join node.
	joinType   plan.JoinType

	// foundMatch records whether the current primary row has produced a match;
	// it is reset whenever a new primary row is loaded.
	foundMatch bool
	// rowSize is the length of the combined row allocated by buildRow.
	rowSize    int
	// scopeLen is the number of leading parent-row columns kept in output rows.
	scopeLen   int
	// b is the builder used to create the range heap's child iterator.
	b          sql.NodeExecBuilder

	// rangeHeapPlan is the right side of the join.
	rangeHeapPlan *plan.RangeHeap
	// childRowIter iterates over rangeHeapPlan.Child; it is created in initializeHeap.
	childRowIter  sql.RowIter
	// pendingRow is the next row read from childRowIter that has not yet been
	// pushed onto the heap; nil once childRowIter is exhausted.
	pendingRow    sql.Row

	// activeRanges is the backing slice of the heap, ordered by each row's
	// value at rangeHeapPlan.MaxColumnIndex.
	activeRanges []sql.Row
	// err holds the first comparison error raised inside Less, since the
	// container/heap callbacks cannot return errors themselves.
	err          error
}
    80  
    81  func (iter *rangeHeapJoinIter) loadPrimary(ctx *sql.Context) error {
    82  	if iter.primaryRow == nil {
    83  		r, err := iter.primary.Next(ctx)
    84  		if err != nil {
    85  			return err
    86  		}
    87  
    88  		iter.primaryRow = iter.parentRow.Append(r)
    89  		iter.foundMatch = false
    90  
    91  		err = iter.initializeHeap(ctx, iter.b, iter.primaryRow)
    92  		if err != nil {
    93  			return err
    94  		}
    95  	}
    96  
    97  	return nil
    98  }
    99  
   100  func (iter *rangeHeapJoinIter) loadSecondary(ctx *sql.Context) (sql.Row, error) {
   101  	if iter.secondary == nil {
   102  		rowIter, err := iter.getActiveRanges(ctx, iter.b, iter.primaryRow)
   103  
   104  		if err != nil {
   105  			return nil, err
   106  		}
   107  		if plan.IsEmptyIter(rowIter) {
   108  			return nil, plan.ErrEmptyCachedResult
   109  		}
   110  		iter.secondary = rowIter
   111  	}
   112  
   113  	secondaryRow, err := iter.secondary.Next(ctx)
   114  	if err != nil {
   115  		if err == io.EOF {
   116  			err = iter.secondary.Close(ctx)
   117  			iter.secondary = nil
   118  			if err != nil {
   119  				return nil, err
   120  			}
   121  			iter.primaryRow = nil
   122  			return nil, io.EOF
   123  		}
   124  		return nil, err
   125  	}
   126  
   127  	return secondaryRow, nil
   128  }
   129  
   130  func (iter *rangeHeapJoinIter) Next(ctx *sql.Context) (sql.Row, error) {
   131  	for {
   132  		if err := iter.loadPrimary(ctx); err != nil {
   133  			return nil, err
   134  		}
   135  
   136  		primary := iter.primaryRow
   137  		secondary, err := iter.loadSecondary(ctx)
   138  		if err != nil {
   139  			if errors.Is(err, io.EOF) {
   140  				if !iter.foundMatch && iter.joinType.IsLeftOuter() {
   141  					iter.primaryRow = nil
   142  					row := iter.buildRow(primary, nil)
   143  					return iter.removeParentRow(row), nil
   144  				}
   145  				continue
   146  			} else if errors.Is(err, plan.ErrEmptyCachedResult) {
   147  				if !iter.foundMatch && iter.joinType.IsLeftOuter() {
   148  					iter.primaryRow = nil
   149  					row := iter.buildRow(primary, nil)
   150  					return iter.removeParentRow(row), nil
   151  				}
   152  
   153  				return nil, io.EOF
   154  			}
   155  			return nil, err
   156  		}
   157  
   158  		row := iter.buildRow(primary, secondary)
   159  		res, err := iter.cond.Eval(ctx, row)
   160  		matches := res == true
   161  		if err != nil {
   162  			return nil, err
   163  		}
   164  
   165  		if res == nil && iter.joinType.IsExcludeNulls() {
   166  			err = iter.secondary.Close(ctx)
   167  			iter.secondary = nil
   168  			if err != nil {
   169  				return nil, err
   170  			}
   171  			iter.primaryRow = nil
   172  			continue
   173  		}
   174  
   175  		if !matches {
   176  			continue
   177  		}
   178  
   179  		iter.foundMatch = true
   180  		return iter.removeParentRow(row), nil
   181  	}
   182  }
   183  
   184  func (iter *rangeHeapJoinIter) removeParentRow(r sql.Row) sql.Row {
   185  	copy(r[iter.scopeLen:], r[len(iter.parentRow):])
   186  	r = r[:len(r)-len(iter.parentRow)+iter.scopeLen]
   187  	return r
   188  }
   189  
   190  // buildRow builds the result set row using the rows from the primary and secondary tables
   191  func (iter *rangeHeapJoinIter) buildRow(primary, secondary sql.Row) sql.Row {
   192  	row := make(sql.Row, iter.rowSize)
   193  
   194  	copy(row, primary)
   195  	copy(row[len(primary):], secondary)
   196  
   197  	return row
   198  }
   199  
   200  func (iter *rangeHeapJoinIter) Close(ctx *sql.Context) (err error) {
   201  	if iter.primary != nil {
   202  		if err = iter.primary.Close(ctx); err != nil {
   203  			if iter.secondary != nil {
   204  				_ = iter.secondary.Close(ctx)
   205  			}
   206  			return err
   207  		}
   208  	}
   209  
   210  	if iter.secondary != nil {
   211  		err = iter.secondary.Close(ctx)
   212  		iter.secondary = nil
   213  	}
   214  
   215  	return err
   216  }
   217  
   218  func (iter *rangeHeapJoinIter) initializeHeap(ctx *sql.Context, builder sql.NodeExecBuilder, primaryRow sql.Row) (err error) {
   219  	iter.childRowIter, err = builder.Build(ctx, iter.rangeHeapPlan.Child, primaryRow)
   220  	if err != nil {
   221  		return err
   222  	}
   223  	iter.activeRanges = nil
   224  	iter.rangeHeapPlan.ComparisonType = iter.rangeHeapPlan.Schema()[iter.rangeHeapPlan.MaxColumnIndex].Type
   225  
   226  	iter.pendingRow, err = iter.childRowIter.Next(ctx)
   227  	if err == io.EOF {
   228  		iter.pendingRow = nil
   229  		return nil
   230  	}
   231  	return err
   232  }
   233  
   234  func (iter *rangeHeapJoinIter) getActiveRanges(ctx *sql.Context, _ sql.NodeExecBuilder, row sql.Row) (sql.RowIter, error) {
   235  	// Remove rows from the heap if we've advanced beyond their max value.
   236  	for iter.Len() > 0 {
   237  		maxValue := iter.Peek()
   238  		compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, row[iter.rangeHeapPlan.ValueColumnIndex], maxValue)
   239  		if err != nil {
   240  			return nil, err
   241  		}
   242  		if (iter.rangeHeapPlan.RangeIsClosedAbove && compareResult > 0) || (!iter.rangeHeapPlan.RangeIsClosedAbove && compareResult >= 0) {
   243  			heap.Pop(iter)
   244  			if iter.err != nil {
   245  				err = iter.err
   246  				iter.err = nil
   247  				return nil, err
   248  			}
   249  		} else {
   250  			break
   251  		}
   252  	}
   253  
   254  	// Advance the child iterator until we encounter a row whose min value is beyond the range.
   255  	for iter.pendingRow != nil {
   256  		minValue := iter.pendingRow[iter.rangeHeapPlan.MinColumnIndex]
   257  		compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, row[iter.rangeHeapPlan.ValueColumnIndex], minValue)
   258  		if err != nil {
   259  			return nil, err
   260  		}
   261  
   262  		if (iter.rangeHeapPlan.RangeIsClosedBelow && compareResult < 0) || (!iter.rangeHeapPlan.RangeIsClosedBelow && compareResult <= 0) {
   263  			break
   264  		} else {
   265  			heap.Push(iter, iter.pendingRow)
   266  			if iter.err != nil {
   267  				err = iter.err
   268  				iter.err = nil
   269  				return nil, err
   270  			}
   271  		}
   272  
   273  		iter.pendingRow, err = iter.childRowIter.Next(ctx)
   274  		if err != nil {
   275  			if errors.Is(err, io.EOF) {
   276  				// We've already imported every range into the priority queue.
   277  				iter.pendingRow = nil
   278  				break
   279  			}
   280  			return nil, err
   281  		}
   282  	}
   283  
   284  	// Every active row must match the accepted row.
   285  	return sql.RowsToRowIter(iter.activeRanges...), nil
   286  }
   287  
   288  // When managing the heap, consider all NULLs to come before any non-NULLS.
   289  // This is consistent with the order received if either child node is an index.
   290  // Note: We could get the same behavior by simply excluding values and ranges containing NULL,
   291  // but this is forward compatible if we ever want to convert joins with null-safe conditions into RangeHeapJoins.
   292  func compareNullsFirst(comparisonType sql.Type, a, b interface{}) (int, error) {
   293  	if a == nil {
   294  		if b == nil {
   295  			return 0, nil
   296  		} else {
   297  			return -1, nil
   298  		}
   299  	}
   300  	if b == nil {
   301  		return 1, nil
   302  	}
   303  	return comparisonType.Compare(a, b)
   304  }
   305  
   306  func (iter rangeHeapJoinIter) Len() int { return len(iter.activeRanges) }
   307  
   308  func (iter *rangeHeapJoinIter) Less(i, j int) bool {
   309  	lhs := iter.activeRanges[i][iter.rangeHeapPlan.MaxColumnIndex]
   310  	rhs := iter.activeRanges[j][iter.rangeHeapPlan.MaxColumnIndex]
   311  	// compareResult will be 0 if lhs==rhs, -1 if lhs < rhs, and +1 if lhs > rhs.
   312  	compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, lhs, rhs)
   313  	if iter.err == nil && err != nil {
   314  		iter.err = err
   315  	}
   316  	return compareResult < 0
   317  }
   318  
   319  func (iter *rangeHeapJoinIter) Swap(i, j int) {
   320  	iter.activeRanges[i], iter.activeRanges[j] = iter.activeRanges[j], iter.activeRanges[i]
   321  }
   322  
   323  func (iter *rangeHeapJoinIter) Push(x any) {
   324  	item := x.(sql.Row)
   325  	iter.activeRanges = append(iter.activeRanges, item)
   326  }
   327  
   328  func (iter *rangeHeapJoinIter) Pop() any {
   329  	n := len(iter.activeRanges)
   330  	x := iter.activeRanges[n-1]
   331  	iter.activeRanges = iter.activeRanges[0 : n-1]
   332  	return x
   333  }
   334  
   335  func (iter *rangeHeapJoinIter) Peek() interface{} {
   336  	n := len(iter.activeRanges)
   337  	return iter.activeRanges[n-1][iter.rangeHeapPlan.MaxColumnIndex]
   338  }