github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/rowexec/mergejoiner.go (about)

     1  // Copyright 2016 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package rowexec
    12  
    13  import (
    14  	"context"
    15  	"fmt"
    16  
    17  	"github.com/cockroachdb/cockroach/pkg/sql/execinfra"
    18  	"github.com/cockroachdb/cockroach/pkg/sql/execinfrapb"
    19  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    20  	"github.com/cockroachdb/cockroach/pkg/util"
    21  	"github.com/cockroachdb/cockroach/pkg/util/humanizeutil"
    22  	"github.com/cockroachdb/cockroach/pkg/util/tracing"
    23  	"github.com/cockroachdb/errors"
    24  	"github.com/opentracing/opentracing-go"
    25  )
    26  
    27  // mergeJoiner performs merge join, it has two input row sources with the same
    28  // ordering on the columns that have equality constraints.
    29  //
    30  // It is guaranteed that the results preserve this ordering.
    31  type mergeJoiner struct {
    32  	joinerBase
    33  
    34  	cancelChecker *sqlbase.CancelChecker
    35  
    36  	leftSource, rightSource execinfra.RowSource
    37  	leftRows, rightRows     []sqlbase.EncDatumRow
    38  	leftIdx, rightIdx       int
    39  	emitUnmatchedRight      bool
    40  	matchedRight            util.FastIntSet
    41  	matchedRightCount       int
    42  
    43  	streamMerger streamMerger
    44  }
    45  
    46  var _ execinfra.Processor = &mergeJoiner{}
    47  var _ execinfra.RowSource = &mergeJoiner{}
    48  var _ execinfra.OpNode = &mergeJoiner{}
    49  
    50  const mergeJoinerProcName = "merge joiner"
    51  
    52  func newMergeJoiner(
    53  	flowCtx *execinfra.FlowCtx,
    54  	processorID int32,
    55  	spec *execinfrapb.MergeJoinerSpec,
    56  	leftSource execinfra.RowSource,
    57  	rightSource execinfra.RowSource,
    58  	post *execinfrapb.PostProcessSpec,
    59  	output execinfra.RowReceiver,
    60  ) (*mergeJoiner, error) {
    61  	leftEqCols := make([]uint32, 0, len(spec.LeftOrdering.Columns))
    62  	rightEqCols := make([]uint32, 0, len(spec.RightOrdering.Columns))
    63  	for i, c := range spec.LeftOrdering.Columns {
    64  		if spec.RightOrdering.Columns[i].Direction != c.Direction {
    65  			return nil, errors.New("unmatched column orderings")
    66  		}
    67  		leftEqCols = append(leftEqCols, c.ColIdx)
    68  		rightEqCols = append(rightEqCols, spec.RightOrdering.Columns[i].ColIdx)
    69  	}
    70  
    71  	m := &mergeJoiner{
    72  		leftSource:  leftSource,
    73  		rightSource: rightSource,
    74  	}
    75  
    76  	if sp := opentracing.SpanFromContext(flowCtx.EvalCtx.Ctx()); sp != nil && tracing.IsRecording(sp) {
    77  		m.leftSource = newInputStatCollector(m.leftSource)
    78  		m.rightSource = newInputStatCollector(m.rightSource)
    79  		m.FinishTrace = m.outputStatsToTrace
    80  	}
    81  
    82  	if err := m.joinerBase.init(
    83  		m /* self */, flowCtx, processorID, leftSource.OutputTypes(), rightSource.OutputTypes(),
    84  		spec.Type, spec.OnExpr, leftEqCols, rightEqCols, 0, post, output,
    85  		execinfra.ProcStateOpts{
    86  			InputsToDrain: []execinfra.RowSource{leftSource, rightSource},
    87  			TrailingMetaCallback: func(context.Context) []execinfrapb.ProducerMetadata {
    88  				m.close()
    89  				return nil
    90  			},
    91  		},
    92  	); err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	m.MemMonitor = execinfra.NewMonitor(flowCtx.EvalCtx.Ctx(), flowCtx.EvalCtx.Mon, "mergejoiner-mem")
    97  
    98  	var err error
    99  	m.streamMerger, err = makeStreamMerger(
   100  		m.leftSource,
   101  		execinfrapb.ConvertToColumnOrdering(spec.LeftOrdering),
   102  		m.rightSource,
   103  		execinfrapb.ConvertToColumnOrdering(spec.RightOrdering),
   104  		spec.NullEquality,
   105  		m.MemMonitor,
   106  	)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	return m, nil
   112  }
   113  
   114  // Start is part of the RowSource interface.
   115  func (m *mergeJoiner) Start(ctx context.Context) context.Context {
   116  	m.streamMerger.start(ctx)
   117  	ctx = m.StartInternal(ctx, mergeJoinerProcName)
   118  	m.cancelChecker = sqlbase.NewCancelChecker(ctx)
   119  	return ctx
   120  }
   121  
   122  // Next is part of the Processor interface.
   123  func (m *mergeJoiner) Next() (sqlbase.EncDatumRow, *execinfrapb.ProducerMetadata) {
   124  	for m.State == execinfra.StateRunning {
   125  		row, meta := m.nextRow()
   126  		if meta != nil {
   127  			if meta.Err != nil {
   128  				m.MoveToDraining(nil /* err */)
   129  			}
   130  			return nil, meta
   131  		}
   132  		if row == nil {
   133  			m.MoveToDraining(nil /* err */)
   134  			break
   135  		}
   136  
   137  		if outRow := m.ProcessRowHelper(row); outRow != nil {
   138  			return outRow, nil
   139  		}
   140  	}
   141  	return nil, m.DrainHelper()
   142  }
   143  
   144  func (m *mergeJoiner) nextRow() (sqlbase.EncDatumRow, *execinfrapb.ProducerMetadata) {
   145  	// The loops below form a restartable state machine that iterates over a
   146  	// batch of rows from the left and right side of the join. The state machine
   147  	// returns a result for every row that should be output.
   148  
   149  	for {
   150  		for m.leftIdx < len(m.leftRows) {
   151  			// We have unprocessed rows from the left-side batch.
   152  			lrow := m.leftRows[m.leftIdx]
   153  			for m.rightIdx < len(m.rightRows) {
   154  				// We have unprocessed rows from the right-side batch.
   155  				ridx := m.rightIdx
   156  				m.rightIdx++
   157  				renderedRow, err := m.render(lrow, m.rightRows[ridx])
   158  				if err != nil {
   159  					return nil, &execinfrapb.ProducerMetadata{Err: err}
   160  				}
   161  				if renderedRow != nil {
   162  					m.matchedRightCount++
   163  					if m.joinType == sqlbase.LeftAntiJoin || m.joinType == sqlbase.ExceptAllJoin {
   164  						break
   165  					}
   166  					if m.emitUnmatchedRight {
   167  						m.matchedRight.Add(ridx)
   168  					}
   169  					if m.joinType == sqlbase.LeftSemiJoin || m.joinType == sqlbase.IntersectAllJoin {
   170  						// Semi-joins and INTERSECT ALL only need to know if there is at
   171  						// least one match, so can skip the rest of the right rows.
   172  						m.rightIdx = len(m.rightRows)
   173  					}
   174  					return renderedRow, nil
   175  				}
   176  			}
   177  
   178  			// Perform the cancellation check. We don't perform this on every row,
   179  			// but once for every iteration through the right-side batch.
   180  			if err := m.cancelChecker.Check(); err != nil {
   181  				return nil, &execinfrapb.ProducerMetadata{Err: err}
   182  			}
   183  
   184  			// We've exhausted the right-side batch. Adjust the indexes for the next
   185  			// row from the left-side of the batch.
   186  			m.leftIdx++
   187  			m.rightIdx = 0
   188  
   189  			// For INTERSECT ALL and EXCEPT ALL, adjust rightIdx to skip all
   190  			// previously matched rows on the next right-side iteration, since we
   191  			// don't want to match them again.
   192  			if m.joinType.IsSetOpJoin() {
   193  				m.rightIdx = m.leftIdx
   194  			}
   195  
   196  			// If we didn't match any rows on the right-side of the batch and this is
   197  			// a left outer join, full outer join, anti join, or EXCEPT ALL, emit an
   198  			// unmatched left-side row.
   199  			if m.matchedRightCount == 0 && shouldEmitUnmatchedRow(leftSide, m.joinType) {
   200  				return m.renderUnmatchedRow(lrow, leftSide), nil
   201  			}
   202  
   203  			m.matchedRightCount = 0
   204  		}
   205  
   206  		// We've exhausted the left-side batch. If this is a right or full outer
   207  		// join (and thus matchedRight!=nil), emit unmatched right-side rows.
   208  		if m.emitUnmatchedRight {
   209  			for m.rightIdx < len(m.rightRows) {
   210  				ridx := m.rightIdx
   211  				m.rightIdx++
   212  				if m.matchedRight.Contains(ridx) {
   213  					continue
   214  				}
   215  				return m.renderUnmatchedRow(m.rightRows[ridx], rightSide), nil
   216  			}
   217  
   218  			m.matchedRight = util.FastIntSet{}
   219  			m.emitUnmatchedRight = false
   220  		}
   221  
   222  		// Retrieve the next batch of rows to process.
   223  		var meta *execinfrapb.ProducerMetadata
   224  		// TODO(paul): Investigate (with benchmarks) whether or not it's
   225  		// worthwhile to only buffer one row from the right stream per batch
   226  		// for semi-joins.
   227  		m.leftRows, m.rightRows, meta = m.streamMerger.NextBatch(m.Ctx, m.EvalCtx)
   228  		if meta != nil {
   229  			return nil, meta
   230  		}
   231  		if m.leftRows == nil && m.rightRows == nil {
   232  			return nil, nil
   233  		}
   234  
   235  		// Prepare for processing the next batch.
   236  		m.emitUnmatchedRight = shouldEmitUnmatchedRow(rightSide, m.joinType)
   237  		m.leftIdx, m.rightIdx = 0, 0
   238  	}
   239  }
   240  
   241  func (m *mergeJoiner) close() {
   242  	if m.InternalClose() {
   243  		ctx := m.Ctx
   244  		m.streamMerger.close(ctx)
   245  		m.MemMonitor.Stop(ctx)
   246  	}
   247  }
   248  
// ConsumerClosed is part of the RowSource interface. It releases the joiner's
// resources by delegating to close.
func (m *mergeJoiner) ConsumerClosed() {
	// The consumer is done, Next() will not be called again.
	m.close()
}
   254  
   255  var _ execinfrapb.DistSQLSpanStats = &MergeJoinerStats{}
   256  
   257  const mergeJoinerTagPrefix = "mergejoiner."
   258  
   259  // Stats implements the SpanStats interface.
   260  func (mjs *MergeJoinerStats) Stats() map[string]string {
   261  	// statsMap starts off as the left input stats map.
   262  	statsMap := mjs.LeftInputStats.Stats(mergeJoinerTagPrefix + "left.")
   263  	rightInputStatsMap := mjs.RightInputStats.Stats(mergeJoinerTagPrefix + "right.")
   264  	// Merge the two input maps.
   265  	for k, v := range rightInputStatsMap {
   266  		statsMap[k] = v
   267  	}
   268  	statsMap[mergeJoinerTagPrefix+MaxMemoryTagSuffix] = humanizeutil.IBytes(mjs.MaxAllocatedMem)
   269  	return statsMap
   270  }
   271  
   272  // StatsForQueryPlan implements the DistSQLSpanStats interface.
   273  func (mjs *MergeJoinerStats) StatsForQueryPlan() []string {
   274  	stats := append(
   275  		mjs.LeftInputStats.StatsForQueryPlan("left "),
   276  		mjs.RightInputStats.StatsForQueryPlan("right ")...,
   277  	)
   278  	if mjs.MaxAllocatedMem != 0 {
   279  		stats =
   280  			append(stats, fmt.Sprintf("%s: %s", MaxMemoryQueryPlanSuffix, humanizeutil.IBytes(mjs.MaxAllocatedMem)))
   281  	}
   282  	return stats
   283  }
   284  
   285  // outputStatsToTrace outputs the collected mergeJoiner stats to the trace. Will
   286  // fail silently if the mergeJoiner is not collecting stats.
   287  func (m *mergeJoiner) outputStatsToTrace() {
   288  	lis, ok := getInputStats(m.FlowCtx, m.leftSource)
   289  	if !ok {
   290  		return
   291  	}
   292  	ris, ok := getInputStats(m.FlowCtx, m.rightSource)
   293  	if !ok {
   294  		return
   295  	}
   296  	if sp := opentracing.SpanFromContext(m.Ctx); sp != nil {
   297  		tracing.SetSpanStats(
   298  			sp,
   299  			&MergeJoinerStats{
   300  				LeftInputStats:  lis,
   301  				RightInputStats: ris,
   302  				MaxAllocatedMem: m.MemMonitor.MaximumBytes(),
   303  			},
   304  		)
   305  	}
   306  }
   307  
   308  // ChildCount is part of the execinfra.OpNode interface.
   309  func (m *mergeJoiner) ChildCount(verbose bool) int {
   310  	if _, ok := m.leftSource.(execinfra.OpNode); ok {
   311  		if _, ok := m.rightSource.(execinfra.OpNode); ok {
   312  			return 2
   313  		}
   314  	}
   315  	return 0
   316  }
   317  
   318  // Child is part of the execinfra.OpNode interface.
   319  func (m *mergeJoiner) Child(nth int, verbose bool) execinfra.OpNode {
   320  	switch nth {
   321  	case 0:
   322  		if n, ok := m.leftSource.(execinfra.OpNode); ok {
   323  			return n
   324  		}
   325  		panic("left input to mergeJoiner is not an execinfra.OpNode")
   326  	case 1:
   327  		if n, ok := m.rightSource.(execinfra.OpNode); ok {
   328  			return n
   329  		}
   330  		panic("right input to mergeJoiner is not an execinfra.OpNode")
   331  	default:
   332  		panic(fmt.Sprintf("invalid index %d", nth))
   333  	}
   334  }