github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/rightsemi/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rightsemi
    16  
    17  import (
    18  	"bytes"
    19  	"time"
    20  
    21  	"github.com/matrixorigin/matrixone/pkg/common/bitmap"
    22  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    23  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    24  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    25  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    26  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    27  	"github.com/matrixorigin/matrixone/pkg/vm"
    28  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    29  )
    30  
    31  const argName = "right_semi"
    32  
    33  func (arg *Argument) String(buf *bytes.Buffer) {
    34  	buf.WriteString(argName)
    35  	buf.WriteString(": right semi join ")
    36  }
    37  
    38  func (arg *Argument) Prepare(proc *process.Process) (err error) {
    39  	ap := arg
    40  	ap.ctr = new(container)
    41  	ap.ctr.InitReceiver(proc, false)
    42  	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    43  	ap.ctr.vecs = make([]*vector.Vector, len(ap.Conditions[0]))
    44  	ap.ctr.evecs = make([]evalVector, len(ap.Conditions[0]))
    45  	for i := range ap.ctr.evecs {
    46  		ap.ctr.evecs[i].executor, err = colexec.NewExpressionExecutor(proc, ap.Conditions[0][i])
    47  		if err != nil {
    48  			return err
    49  		}
    50  	}
    51  
    52  	if ap.Cond != nil {
    53  		ap.ctr.expr, err = colexec.NewExpressionExecutor(proc, ap.Cond)
    54  	}
    55  	ap.ctr.tmpBatches = make([]*batch.Batch, 2)
    56  	return err
    57  }
    58  
    59  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    60  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    61  		return vm.CancelResult, err
    62  	}
    63  
    64  	analyze := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    65  	analyze.Start()
    66  	defer analyze.Stop()
    67  	ap := arg
    68  	ctr := ap.ctr
    69  	result := vm.NewCallResult()
    70  	for {
    71  		switch ctr.state {
    72  		case Build:
    73  			if err := ctr.build(analyze); err != nil {
    74  				return result, err
    75  			}
    76  			if ctr.mp == nil && !arg.IsShuffle {
    77  				// for inner ,right and semi join, if hashmap is empty, we can finish this pipeline
    78  				// shuffle join can't stop early for this moment
    79  				ctr.state = End
    80  			} else {
    81  				ctr.state = Probe
    82  			}
    83  
    84  		case Probe:
    85  			bat, _, err := ctr.ReceiveFromSingleReg(0, analyze)
    86  			if err != nil {
    87  				return result, err
    88  			}
    89  
    90  			if bat == nil {
    91  				ctr.state = SendLast
    92  				continue
    93  			}
    94  			if bat.IsEmpty() {
    95  				proc.PutBatch(bat)
    96  				continue
    97  			}
    98  
    99  			if ctr.batchRowCount == 0 {
   100  				proc.PutBatch(bat)
   101  				continue
   102  			}
   103  
   104  			if err = ctr.probe(bat, ap, proc, analyze, arg.GetIsFirst(), arg.GetIsLast()); err != nil {
   105  				bat.Clean(proc.Mp())
   106  				return result, err
   107  			}
   108  			proc.PutBatch(bat)
   109  			continue
   110  
   111  		case SendLast:
   112  			if ap.rbat == nil {
   113  				ap.lastpos = 0
   114  				setNil, err := ctr.sendLast(ap, proc, analyze, arg.GetIsFirst(), arg.GetIsLast())
   115  				if err != nil {
   116  					return result, err
   117  				}
   118  				if setNil {
   119  					ctr.state = End
   120  				}
   121  				continue
   122  			} else {
   123  				if ap.lastpos >= len(ap.rbat) {
   124  					ctr.state = End
   125  					continue
   126  				}
   127  				result.Batch = ap.rbat[ap.lastpos]
   128  				ap.lastpos++
   129  				return result, nil
   130  			}
   131  
   132  		default:
   133  			result.Batch = nil
   134  			result.Status = vm.ExecStop
   135  			return result, nil
   136  		}
   137  	}
   138  }
   139  
   140  func (ctr *container) receiveHashMap(anal process.Analyze) error {
   141  	bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   142  	if err != nil {
   143  		return err
   144  	}
   145  	if bat != nil && bat.AuxData != nil {
   146  		ctr.mp = bat.DupJmAuxData()
   147  		ctr.maxAllocSize = max(ctr.maxAllocSize, ctr.mp.Size())
   148  	}
   149  	return nil
   150  }
   151  
   152  func (ctr *container) receiveBatch(anal process.Analyze) error {
   153  	for {
   154  		bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   155  		if err != nil {
   156  			return err
   157  		}
   158  		if bat != nil {
   159  			ctr.batchRowCount += bat.RowCount()
   160  			ctr.batches = append(ctr.batches, bat)
   161  		} else {
   162  			break
   163  		}
   164  	}
   165  	for i := 0; i < len(ctr.batches)-1; i++ {
   166  		if ctr.batches[i].RowCount() != colexec.DefaultBatchSize {
   167  			panic("wrong batch received for hash build!")
   168  		}
   169  	}
   170  	if ctr.batchRowCount > 0 {
   171  		ctr.matched = &bitmap.Bitmap{}
   172  		ctr.matched.InitWithSize(int64(ctr.batchRowCount))
   173  	}
   174  	return nil
   175  }
   176  
   177  func (ctr *container) build(anal process.Analyze) error {
   178  	err := ctr.receiveHashMap(anal)
   179  	if err != nil {
   180  		return err
   181  	}
   182  	return ctr.receiveBatch(anal)
   183  }
   184  
   185  func (ctr *container) sendLast(ap *Argument, proc *process.Process, analyze process.Analyze, _ bool, isLast bool) (bool, error) {
   186  	ctr.handledLast = true
   187  
   188  	if ctr.matched == nil {
   189  		return true, nil
   190  	}
   191  
   192  	if ap.NumCPU > 1 {
   193  		if !ap.IsMerger {
   194  
   195  			sendStart := time.Now()
   196  			ap.Channel <- ctr.matched
   197  			analyze.WaitStop(sendStart)
   198  
   199  			return true, nil
   200  		} else {
   201  			cnt := 1
   202  
   203  			receiveStart := time.Now()
   204  
   205  			// The original code didn't handle the context correctly and would cause the system to HUNG!
   206  			for completed := true; completed; {
   207  				select {
   208  				case <-proc.Ctx.Done():
   209  					return true, moerr.NewInternalError(proc.Ctx, "query has been closed early")
   210  				case v := <-ap.Channel:
   211  					ctr.matched.Or(v)
   212  					cnt++
   213  					if cnt == int(ap.NumCPU) {
   214  						close(ap.Channel)
   215  						completed = false
   216  					}
   217  				}
   218  			}
   219  			analyze.WaitStop(receiveStart)
   220  
   221  		}
   222  	}
   223  
   224  	count := ctr.matched.Count()
   225  	sels := make([]int32, 0, count)
   226  	itr := ctr.matched.Iterator()
   227  	for itr.HasNext() {
   228  		r := itr.Next()
   229  		sels = append(sels, int32(r))
   230  	}
   231  
   232  	if len(sels) <= colexec.DefaultBatchSize {
   233  		if ctr.rbat != nil {
   234  			proc.PutBatch(ctr.rbat)
   235  			ctr.rbat = nil
   236  		}
   237  		ctr.rbat = batch.NewWithSize(len(ap.Result))
   238  
   239  		for i, pos := range ap.Result {
   240  			ctr.rbat.Vecs[i] = proc.GetVector(ap.RightTypes[pos])
   241  		}
   242  		for j, pos := range ap.Result {
   243  			for _, sel := range sels {
   244  				idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   245  				if err := ctr.rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[pos], int64(idx2), proc.Mp()); err != nil {
   246  					return false, err
   247  				}
   248  			}
   249  		}
   250  		ctr.rbat.AddRowCount(len(sels))
   251  
   252  		analyze.Output(ctr.rbat, isLast)
   253  		ap.rbat = []*batch.Batch{ctr.rbat}
   254  		return false, nil
   255  	} else {
   256  		n := (len(sels)-1)/colexec.DefaultBatchSize + 1
   257  		ap.rbat = make([]*batch.Batch, n)
   258  		for k := range ap.rbat {
   259  			ap.rbat[k] = batch.NewWithSize(len(ap.Result))
   260  			for i, pos := range ap.Result {
   261  				ap.rbat[k].Vecs[i] = proc.GetVector(ap.RightTypes[pos])
   262  			}
   263  			var newsels []int32
   264  			if (k+1)*colexec.DefaultBatchSize <= len(sels) {
   265  				newsels = sels[k*colexec.DefaultBatchSize : (k+1)*colexec.DefaultBatchSize]
   266  			} else {
   267  				newsels = sels[k*colexec.DefaultBatchSize:]
   268  			}
   269  			for i, pos := range ap.Result {
   270  				for _, sel := range newsels {
   271  					idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   272  					if err := ap.rbat[k].Vecs[i].UnionOne(ctr.batches[idx1].Vecs[pos], int64(idx2), proc.Mp()); err != nil {
   273  						return false, err
   274  					}
   275  				}
   276  			}
   277  			ap.rbat[k].SetRowCount(len(newsels))
   278  			analyze.Output(ap.rbat[k], isLast)
   279  		}
   280  		return false, nil
   281  	}
   282  
   283  }
   284  
   285  func (ctr *container) probe(bat *batch.Batch, ap *Argument, proc *process.Process, analyze process.Analyze, isFirst bool, _ bool) error {
   286  	analyze.Input(bat, isFirst)
   287  
   288  	if err := ctr.evalJoinCondition(bat, proc); err != nil {
   289  		return err
   290  	}
   291  	if ctr.joinBat1 == nil {
   292  		ctr.joinBat1, ctr.cfs1 = colexec.NewJoinBatch(bat, proc.Mp())
   293  	}
   294  	if ctr.joinBat2 == nil {
   295  		ctr.joinBat2, ctr.cfs2 = colexec.NewJoinBatch(ctr.batches[0], proc.Mp())
   296  	}
   297  	count := bat.RowCount()
   298  	mSels := ctr.mp.Sels()
   299  	itr := ctr.mp.NewIterator()
   300  	for i := 0; i < count; i += hashmap.UnitLimit {
   301  		n := count - i
   302  		if n > hashmap.UnitLimit {
   303  			n = hashmap.UnitLimit
   304  		}
   305  		copy(ctr.inBuckets, hashmap.OneUInt8s)
   306  		vals, zvals := itr.Find(i, n, ctr.vecs, ctr.inBuckets)
   307  		for k := 0; k < n; k++ {
   308  			if ctr.inBuckets[k] == 0 || zvals[k] == 0 || vals[k] == 0 {
   309  				continue
   310  			}
   311  			if ap.HashOnPK {
   312  				idx1, idx2 := int64(vals[k]-1)/colexec.DefaultBatchSize, int64(vals[k]-1)%colexec.DefaultBatchSize
   313  				if ctr.matched.Contains(vals[k] - 1) {
   314  					continue
   315  				}
   316  				if ap.Cond != nil {
   317  					if err := colexec.SetJoinBatchValues(ctr.joinBat1, bat, int64(i+k),
   318  						1, ctr.cfs1); err != nil {
   319  						return err
   320  					}
   321  					if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], idx2,
   322  						1, ctr.cfs2); err != nil {
   323  						return err
   324  					}
   325  					ctr.tmpBatches[0] = ctr.joinBat1
   326  					ctr.tmpBatches[1] = ctr.joinBat2
   327  					vec, err := ctr.expr.Eval(proc, ctr.tmpBatches)
   328  					if err != nil {
   329  						return err
   330  					}
   331  					if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   332  						continue
   333  					} else {
   334  						vcol := vector.MustFixedCol[bool](vec)
   335  						if !vcol[0] {
   336  							continue
   337  						}
   338  					}
   339  				}
   340  				ctr.matched.Add(vals[k] - 1)
   341  			} else {
   342  				sels := mSels[vals[k]-1]
   343  				for _, sel := range sels {
   344  					if ctr.matched.Contains(uint64(sel)) {
   345  						continue
   346  					}
   347  					idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   348  					if ap.Cond != nil {
   349  						if err := colexec.SetJoinBatchValues(ctr.joinBat1, bat, int64(i+k),
   350  							1, ctr.cfs1); err != nil {
   351  							return err
   352  						}
   353  						if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], int64(idx2),
   354  							1, ctr.cfs2); err != nil {
   355  							return err
   356  						}
   357  						ctr.tmpBatches[0] = ctr.joinBat1
   358  						ctr.tmpBatches[1] = ctr.joinBat2
   359  						vec, err := ctr.expr.Eval(proc, ctr.tmpBatches)
   360  						if err != nil {
   361  							return err
   362  						}
   363  						if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   364  							continue
   365  						} else {
   366  							vcol := vector.MustFixedCol[bool](vec)
   367  							if !vcol[0] {
   368  								continue
   369  							}
   370  						}
   371  					}
   372  					ctr.matched.Add(uint64(sel))
   373  				}
   374  			}
   375  
   376  		}
   377  	}
   378  	return nil
   379  }
   380  
   381  func (ctr *container) evalJoinCondition(bat *batch.Batch, proc *process.Process) error {
   382  	for i := range ctr.evecs {
   383  		vec, err := ctr.evecs[i].executor.Eval(proc, []*batch.Batch{bat})
   384  		if err != nil {
   385  			return err
   386  		}
   387  		ctr.vecs[i] = vec
   388  		ctr.evecs[i].vec = vec
   389  	}
   390  	return nil
   391  }