github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/interlock/merge_join.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
        // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package interlock
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  
    20  	"github.com/whtcorpsinc/failpoint"
    21  	"github.com/whtcorpsinc/milevadb/config"
    22  	"github.com/whtcorpsinc/milevadb/memex"
    23  	"github.com/whtcorpsinc/milevadb/stochastikctx/stmtctx"
    24  	"github.com/whtcorpsinc/milevadb/soliton/chunk"
    25  	"github.com/whtcorpsinc/milevadb/soliton/disk"
    26  	"github.com/whtcorpsinc/milevadb/soliton/memory"
    27  	"github.com/whtcorpsinc/milevadb/soliton/stringutil"
    28  )
    29  
// MergeJoinInterDirc implements the merge join algorithm.
// This operator assumes that the iterators of both sides
// will provide the required order on the join condition:
// 1. For equal-join, one of the join keys from each side
// matches the order given.
// 2. For other cases it's preferred not to use SMJ and the operator
// will throw an error.
type MergeJoinInterDirc struct {
	baseInterlockingDirectorate

	// stmtCtx is not referenced in this file's visible code; presumably set by
	// the builder — confirm before relying on it.
	stmtCtx *stmtctx.StatementContext
	// compareFuncs[i] compares outer join key i against inner join key i (see compare).
	compareFuncs []memex.CompareFunc
	joiner       joiner
	// isOuterJoin allows fetchNextOuterGroup to push the remaining required row
	// count down to the outer child when the outer side has no filters.
	isOuterJoin bool
	// desc is true when both inputs are ordered by descending join key; Next
	// flips the direction of every comparison accordingly.
	desc bool

	innerBlock *mergeJoinBlock
	outerBlock *mergeJoinBlock

	// hasMatch/hasNull accumulate the match state of the current outer row.
	// They live on the receiver because Next may return with a full chunk in
	// the middle of matching one outer row and resume on the next call.
	hasMatch bool
	hasNull  bool

	// memTracker and diskTracker are created in Open and cleared in Close.
	memTracker  *memory.Tracker
	diskTracker *disk.Tracker
}
    55  
    56  var (
    57  	innerBlockLabel fmt.Stringer = stringutil.StringerStr("innerBlock")
    58  	outerBlockLabel fmt.Stringer = stringutil.StringerStr("outerBlock")
    59  )
    60  
// mergeJoinBlock holds the per-side state of a merge join: which child it
// reads from, its join keys and filters, the chunk currently being consumed,
// and the group of equal-key rows currently being joined.
type mergeJoinBlock struct {
	isInner    bool
	childIndex int
	joinKeys   []*memex.DeferredCauset
	// filters are evaluated only for the outer side in this file
	// (see fetchNextOuterGroup).
	filters []memex.Expression

	// executed is set once the child has returned an empty chunk.
	executed       bool
	childChunk     *chunk.Chunk
	childChunkIter *chunk.Iterator4Chunk
	// groupChecker splits childChunk into runs of rows with equal join keys.
	groupChecker *vecGroupChecker
	// groupEventsSelected holds the childChunk row indexes of the current
	// group; selectNextGroup also installs it as childChunk's selection.
	groupEventsSelected []int
	// groupEventsIter iterates over the rows of the current group.
	groupEventsIter chunk.Iterator

	// For the inner side, an unbroken group may span many chunks; the chunks
	// of the current group that are already complete are kept here.
	rowContainer *chunk.EventContainer

	// For the outer side, saves the result of filters, indexed by row index
	// within childChunk.
	filtersSelected []bool

	memTracker *memory.Tracker
}
    82  
    83  func (t *mergeJoinBlock) init(exec *MergeJoinInterDirc) {
    84  	child := exec.children[t.childIndex]
    85  	t.childChunk = newFirstChunk(child)
    86  	t.childChunkIter = chunk.NewIterator4Chunk(t.childChunk)
    87  
    88  	items := make([]memex.Expression, 0, len(t.joinKeys))
    89  	for _, defCaus := range t.joinKeys {
    90  		items = append(items, defCaus)
    91  	}
    92  	t.groupChecker = newVecGroupChecker(exec.ctx, items)
    93  	t.groupEventsIter = chunk.NewIterator4Chunk(t.childChunk)
    94  
    95  	if t.isInner {
    96  		t.rowContainer = chunk.NewEventContainer(child.base().retFieldTypes, t.childChunk.Capacity())
    97  		t.rowContainer.GetMemTracker().AttachTo(exec.memTracker)
    98  		t.rowContainer.GetMemTracker().SetLabel(memory.LabelForInnerBlock)
    99  		t.rowContainer.GetDiskTracker().AttachTo(exec.diskTracker)
   100  		t.rowContainer.GetDiskTracker().SetLabel(memory.LabelForInnerBlock)
   101  		if config.GetGlobalConfig().OOMUseTmpStorage {
   102  			actionSpill := t.rowContainer.CausetActionSpill()
   103  			failpoint.Inject("testMergeJoinEventContainerSpill", func(val failpoint.Value) {
   104  				if val.(bool) {
   105  					actionSpill = t.rowContainer.CausetActionSpillForTest()
   106  				}
   107  			})
   108  			exec.ctx.GetStochastikVars().StmtCtx.MemTracker.FallbackOldAndSetNewCausetAction(actionSpill)
   109  		}
   110  		t.memTracker = memory.NewTracker(memory.LabelForInnerBlock, -1)
   111  	} else {
   112  		t.filtersSelected = make([]bool, 0, exec.maxChunkSize)
   113  		t.memTracker = memory.NewTracker(memory.LabelForOuterBlock, -1)
   114  	}
   115  
   116  	t.memTracker.AttachTo(exec.memTracker)
   117  	t.memTracker.Consume(t.childChunk.MemoryUsage())
   118  }
   119  
   120  func (t *mergeJoinBlock) finish() error {
   121  	t.memTracker.Consume(-t.childChunk.MemoryUsage())
   122  
   123  	if t.isInner {
   124  		failpoint.Inject("testMergeJoinEventContainerSpill", func(val failpoint.Value) {
   125  			if val.(bool) {
   126  				actionSpill := t.rowContainer.CausetActionSpill()
   127  				actionSpill.WaitForTest()
   128  			}
   129  		})
   130  		if err := t.rowContainer.Close(); err != nil {
   131  			return err
   132  		}
   133  	}
   134  
   135  	t.executed = false
   136  	t.childChunk = nil
   137  	t.childChunkIter = nil
   138  	t.groupChecker = nil
   139  	t.groupEventsSelected = nil
   140  	t.groupEventsIter = nil
   141  	t.rowContainer = nil
   142  	t.filtersSelected = nil
   143  	t.memTracker = nil
   144  	return nil
   145  }
   146  
   147  func (t *mergeJoinBlock) selectNextGroup() {
   148  	t.groupEventsSelected = t.groupEventsSelected[:0]
   149  	begin, end := t.groupChecker.getNextGroup()
   150  	if t.isInner && t.hasNullInJoinKey(t.childChunk.GetEvent(begin)) {
   151  		return
   152  	}
   153  
   154  	for i := begin; i < end; i++ {
   155  		t.groupEventsSelected = append(t.groupEventsSelected, i)
   156  	}
   157  	t.childChunk.SetSel(t.groupEventsSelected)
   158  }
   159  
   160  func (t *mergeJoinBlock) fetchNextChunk(ctx context.Context, exec *MergeJoinInterDirc) error {
   161  	oldMemUsage := t.childChunk.MemoryUsage()
   162  	err := Next(ctx, exec.children[t.childIndex], t.childChunk)
   163  	t.memTracker.Consume(t.childChunk.MemoryUsage() - oldMemUsage)
   164  	if err != nil {
   165  		return err
   166  	}
   167  	t.executed = t.childChunk.NumEvents() == 0
   168  	return nil
   169  }
   170  
// fetchNextInnerGroup loads the next group of inner rows that share the same
// join keys and points t.groupEventsIter at it. An inner group may continue
// into following chunks, so chunks that belong to the group are handed to
// t.rowContainer. The resulting iterator walks t.rowContainer and then the
// group's rows still in t.childChunk. Groups whose join key contains NULL are
// skipped. When the inner child is exhausted, t.groupEventsIter is moved to
// its end and nil is returned.
func (t *mergeJoinBlock) fetchNextInnerGroup(ctx context.Context, exec *MergeJoinInterDirc) error {
	// Clear the previous group's selection and reset the container it used.
	t.childChunk.SetSel(nil)
	if err := t.rowContainer.Reset(); err != nil {
		return err
	}

fetchNext:
	if t.executed && t.groupChecker.isExhausted() {
		// Ensure iter at the end, since sel of childChunk has been cleared.
		t.groupEventsIter.ReachEnd()
		return nil
	}

	isEmpty := true
	// For the inner side, groups with NULL in the join keys are skipped by
	// selectNextGroup (they come back empty), so keep selecting until a
	// non-empty group is found or the current chunk is fully grouped.
	for isEmpty && !t.groupChecker.isExhausted() {
		t.selectNextGroup()
		isEmpty = len(t.groupEventsSelected) == 0
	}

	// For the inner side, all rows with the same join keys must end up in one
	// group. While the current chunk is fully grouped, move it into
	// rowContainer (when it holds part of the group) and fetch the next chunk.
	// If that chunk starts with the same keys, its first group extends the
	// current one.
	for !t.executed && t.groupChecker.isExhausted() {
		if !isEmpty {
			// Group is not empty, hand over the management of childChunk to t.rowContainer.
			// childChunk's selection was set to the group's rows by selectNextGroup.
			if err := t.rowContainer.Add(t.childChunk); err != nil {
				return err
			}
			t.memTracker.Consume(-t.childChunk.MemoryUsage())
			t.groupEventsSelected = nil

			// Continue reading into a fresh chunk allocated by the container.
			t.childChunk = t.rowContainer.AllocChunk()
			t.childChunkIter = chunk.NewIterator4Chunk(t.childChunk)
			t.memTracker.Consume(t.childChunk.MemoryUsage())
		}

		if err := t.fetchNextChunk(ctx, exec); err != nil {
			return err
		}
		if t.executed {
			break
		}

		isFirstGroupSameAsPrev, err := t.groupChecker.splitIntoGroups(t.childChunk)
		if err != nil {
			return err
		}
		if isFirstGroupSameAsPrev && !isEmpty {
			t.selectNextGroup()
		}
	}
	// Every group examined so far was skipped; start over with what is left.
	if isEmpty {
		goto fetchNext
	}

	// iterate all data in t.rowContainer and t.childChunk
	// NOTE(review): at this point either rowContainer holds at least one chunk
	// or groupEventsSelected is non-empty, so iter should be non-nil below.
	var iter chunk.Iterator
	if t.rowContainer.NumChunks() != 0 {
		iter = chunk.NewIterator4EventContainer(t.rowContainer)
	}
	if len(t.groupEventsSelected) != 0 {
		if iter != nil {
			iter = chunk.NewMultiIterator(iter, t.childChunkIter)
		} else {
			iter = t.childChunkIter
		}
	}
	t.groupEventsIter = iter
	t.groupEventsIter.Begin()
	return nil
}
   241  
   242  func (t *mergeJoinBlock) fetchNextOuterGroup(ctx context.Context, exec *MergeJoinInterDirc, requiredEvents int) error {
   243  	if t.executed && t.groupChecker.isExhausted() {
   244  		return nil
   245  	}
   246  
   247  	if !t.executed && t.groupChecker.isExhausted() {
   248  		// It's hard to calculate selectivity if there is any filter or it's inner join,
   249  		// so we just push the requiredEvents down when it's outer join and has no filter.
   250  		if exec.isOuterJoin && len(t.filters) == 0 {
   251  			t.childChunk.SetRequiredEvents(requiredEvents, exec.maxChunkSize)
   252  		}
   253  		err := t.fetchNextChunk(ctx, exec)
   254  		if err != nil || t.executed {
   255  			return err
   256  		}
   257  
   258  		t.childChunkIter.Begin()
   259  		t.filtersSelected, err = memex.VectorizedFilter(exec.ctx, t.filters, t.childChunkIter, t.filtersSelected)
   260  		if err != nil {
   261  			return err
   262  		}
   263  
   264  		_, err = t.groupChecker.splitIntoGroups(t.childChunk)
   265  		if err != nil {
   266  			return err
   267  		}
   268  	}
   269  
   270  	t.selectNextGroup()
   271  	t.groupEventsIter.Begin()
   272  	return nil
   273  }
   274  
   275  func (t *mergeJoinBlock) hasNullInJoinKey(event chunk.Event) bool {
   276  	for _, defCaus := range t.joinKeys {
   277  		ordinal := defCaus.Index
   278  		if event.IsNull(ordinal) {
   279  			return true
   280  		}
   281  	}
   282  	return false
   283  }
   284  
   285  // Close implements the InterlockingDirectorate Close interface.
   286  func (e *MergeJoinInterDirc) Close() error {
   287  	if err := e.innerBlock.finish(); err != nil {
   288  		return err
   289  	}
   290  	if err := e.outerBlock.finish(); err != nil {
   291  		return err
   292  	}
   293  
   294  	e.hasMatch = false
   295  	e.hasNull = false
   296  	e.memTracker = nil
   297  	e.diskTracker = nil
   298  	return e.baseInterlockingDirectorate.Close()
   299  }
   300  
   301  // Open implements the InterlockingDirectorate Open interface.
   302  func (e *MergeJoinInterDirc) Open(ctx context.Context) error {
   303  	if err := e.baseInterlockingDirectorate.Open(ctx); err != nil {
   304  		return err
   305  	}
   306  
   307  	e.memTracker = memory.NewTracker(e.id, -1)
   308  	e.memTracker.AttachTo(e.ctx.GetStochastikVars().StmtCtx.MemTracker)
   309  	e.diskTracker = disk.NewTracker(e.id, -1)
   310  	e.diskTracker.AttachTo(e.ctx.GetStochastikVars().StmtCtx.DiskTracker)
   311  
   312  	e.innerBlock.init(e)
   313  	e.outerBlock.init(e)
   314  	return nil
   315  }
   316  
// Next implements the InterlockingDirectorate Next interface.
// Note the inner group collects all rows with identical keys across multiple
// chunks, but the outer group only covers the identical keys within one chunk,
// so a run of identical outer keys may span more than one outer group.
//
// Each loop iteration makes sure both sides have a current group, compares the
// current outer row with the current inner row, and then does one of three
// things:
//   - the inner group falls behind: it is discarded;
//   - the outer group falls behind: its remaining rows are emitted as unmatched;
//   - the keys are equal: each outer row that passed the filters is joined with
//     the whole inner group, and rows without a match go to joiner.onMissMatch.
//
// The iterators and hasMatch/hasNull are kept on the blocks and the receiver,
// so Next can return as soon as req is full and resume on the next call.
func (e *MergeJoinInterDirc) Next(ctx context.Context, req *chunk.Chunk) (err error) {
	req.Reset()

	innerIter := e.innerBlock.groupEventsIter
	outerIter := e.outerBlock.groupEventsIter
	for !req.IsFull() {
		if innerIter.Current() == innerIter.End() {
			if err := e.innerBlock.fetchNextInnerGroup(ctx, e); err != nil {
				return err
			}
			innerIter = e.innerBlock.groupEventsIter
		}
		if outerIter.Current() == outerIter.End() {
			if err := e.outerBlock.fetchNextOuterGroup(ctx, e, req.RequiredEvents()-req.NumEvents()); err != nil {
				return err
			}
			outerIter = e.outerBlock.groupEventsIter
			// The outer side is exhausted, so no more output can be produced.
			if e.outerBlock.executed {
				return nil
			}
		}

		// If the inner side has no current row (exhausted), this default makes
		// the outer group "fall behind", so its rows are emitted as unmatched.
		cmpResult := -1
		if e.desc {
			cmpResult = 1
		}
		if innerIter.Current() != innerIter.End() {
			cmpResult, err = e.compare(outerIter.Current(), innerIter.Current())
			if err != nil {
				return err
			}
		}
		// the inner group falls behind
		if (cmpResult > 0 && !e.desc) || (cmpResult < 0 && e.desc) {
			innerIter.ReachEnd()
			continue
		}
		// the outer group falls behind
		if (cmpResult < 0 && !e.desc) || (cmpResult > 0 && e.desc) {
			for event := outerIter.Current(); event != outerIter.End() && !req.IsFull(); event = outerIter.Next() {
				e.joiner.onMissMatch(false, event, req)
			}
			continue
		}

		// Equal keys: join the outer group with the inner group.
		for event := outerIter.Current(); event != outerIter.End() && !req.IsFull(); event = outerIter.Next() {
			// filtersSelected is indexed by the row's index within the outer childChunk.
			if !e.outerBlock.filtersSelected[event.Idx()] {
				e.joiner.onMissMatch(false, event, req)
				continue
			}
			// compare each outer item with each inner item
			// the inner maybe not exhausted at one time
			for innerIter.Current() != innerIter.End() {
				matched, isNull, err := e.joiner.tryToMatchInners(event, innerIter, req)
				if err != nil {
					return err
				}
				e.hasMatch = e.hasMatch || matched
				e.hasNull = e.hasNull || isNull
				if req.IsFull() {
					// If the inner group is fully consumed for this outer row,
					// finish the row below. Otherwise return now and resume this
					// outer row against the rest of the inner group next call.
					if innerIter.Current() == innerIter.End() {
						break
					}
					return nil
				}
			}

			if !e.hasMatch {
				e.joiner.onMissMatch(e.hasNull, event, req)
			}
			// Reset per-row state and rewind the inner group for the next outer row.
			e.hasMatch = false
			e.hasNull = false
			innerIter.Begin()
		}
	}
	return nil
}
   397  
   398  func (e *MergeJoinInterDirc) compare(outerEvent, innerEvent chunk.Event) (int, error) {
   399  	outerJoinKeys := e.outerBlock.joinKeys
   400  	innerJoinKeys := e.innerBlock.joinKeys
   401  	for i := range outerJoinKeys {
   402  		cmp, _, err := e.compareFuncs[i](e.ctx, outerJoinKeys[i], innerJoinKeys[i], outerEvent, innerEvent)
   403  		if err != nil {
   404  			return 0, err
   405  		}
   406  
   407  		if cmp != 0 {
   408  			return int(cmp), nil
   409  		}
   410  	}
   411  	return 0, nil
   412  }