github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/right/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package right
    16  
    17  import (
    18  	"bytes"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/common/bitmap"
    21  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    22  
    23  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    24  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    25  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    26  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    27  	"github.com/matrixorigin/matrixone/pkg/vm"
    28  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    29  )
    30  
    31  const argName = "right"
    32  
    33  func (arg *Argument) String(buf *bytes.Buffer) {
    34  	buf.WriteString(argName)
    35  	buf.WriteString(": right join ")
    36  }
    37  
    38  func (arg *Argument) Prepare(proc *process.Process) (err error) {
    39  	ap := arg
    40  	ap.ctr = new(container)
    41  	ap.ctr.InitReceiver(proc, false)
    42  	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    43  	ap.ctr.vecs = make([]*vector.Vector, len(ap.Conditions[0]))
    44  
    45  	ap.ctr.evecs = make([]evalVector, len(ap.Conditions[0]))
    46  	for i := range ap.Conditions[0] {
    47  		ap.ctr.evecs[i].executor, err = colexec.NewExpressionExecutor(proc, ap.Conditions[0][i])
    48  		if err != nil {
    49  			return err
    50  		}
    51  	}
    52  	if ap.Cond != nil {
    53  		ap.ctr.expr, err = colexec.NewExpressionExecutor(proc, ap.Cond)
    54  	}
    55  	ap.ctr.handledLast = false
    56  	return err
    57  }
    58  
    59  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    60  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    61  		return vm.CancelResult, err
    62  	}
    63  
    64  	analyze := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    65  	analyze.Start()
    66  	defer analyze.Stop()
    67  	ap := arg
    68  	ctr := ap.ctr
    69  	result := vm.NewCallResult()
    70  	for {
    71  		switch ctr.state {
    72  		case Build:
    73  			if err := ctr.build(analyze); err != nil {
    74  				return result, err
    75  			}
    76  			if ctr.mp == nil && !arg.IsShuffle {
    77  				// for inner ,right and semi join, if hashmap is empty, we can finish this pipeline
    78  				// shuffle join can't stop early for this moment
    79  				ctr.state = End
    80  			} else {
    81  				ctr.state = Probe
    82  			}
    83  
    84  		case Probe:
    85  			if ap.bat == nil {
    86  				bat, _, err := ctr.ReceiveFromSingleReg(0, analyze)
    87  				if err != nil {
    88  					return result, err
    89  				}
    90  
    91  				if bat == nil {
    92  					ctr.state = SendLast
    93  					ap.rbat = nil
    94  					continue
    95  				}
    96  				if bat.IsEmpty() {
    97  					proc.PutBatch(bat)
    98  					continue
    99  				}
   100  				if ctr.mp == nil {
   101  					proc.PutBatch(bat)
   102  					continue
   103  				}
   104  				ap.bat = bat
   105  				ap.lastpos = 0
   106  			}
   107  
   108  			startrow := ap.lastpos
   109  			if err := ctr.probe(ap, proc, analyze, arg.GetIsFirst(), arg.GetIsLast(), &result); err != nil {
   110  				return result, err
   111  			}
   112  			if ap.lastpos == 0 {
   113  				proc.PutBatch(ap.bat)
   114  				ap.bat = nil
   115  			} else if ap.lastpos == startrow {
   116  				return result, moerr.NewInternalErrorNoCtx("right join hanging")
   117  			}
   118  			return result, nil
   119  
   120  		case SendLast:
   121  			setNil, err := ctr.sendLast(ap, proc, analyze, arg.GetIsFirst(), arg.GetIsLast(), &result)
   122  			if err != nil {
   123  				return result, err
   124  			}
   125  
   126  			ctr.state = End
   127  			if setNil {
   128  				continue
   129  			}
   130  			return result, nil
   131  
   132  		default:
   133  			result.Batch = nil
   134  			result.Status = vm.ExecStop
   135  			return result, nil
   136  		}
   137  	}
   138  }
   139  
   140  func (ctr *container) receiveHashMap(anal process.Analyze) error {
   141  	bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   142  	if err != nil {
   143  		return err
   144  	}
   145  	if bat != nil && bat.AuxData != nil {
   146  		ctr.mp = bat.DupJmAuxData()
   147  		ctr.maxAllocSize = max(ctr.maxAllocSize, ctr.mp.Size())
   148  	}
   149  	return nil
   150  }
   151  
   152  func (ctr *container) receiveBatch(anal process.Analyze) error {
   153  	for {
   154  		bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   155  		if err != nil {
   156  			return err
   157  		}
   158  		if bat != nil {
   159  			ctr.batchRowCount += bat.RowCount()
   160  			ctr.batches = append(ctr.batches, bat)
   161  		} else {
   162  			break
   163  		}
   164  	}
   165  	for i := 0; i < len(ctr.batches)-1; i++ {
   166  		if ctr.batches[i].RowCount() != colexec.DefaultBatchSize {
   167  			panic("wrong batch received for hash build!")
   168  		}
   169  	}
   170  	if ctr.batchRowCount > 0 {
   171  		ctr.matched = &bitmap.Bitmap{}
   172  		ctr.matched.InitWithSize(int64(ctr.batchRowCount))
   173  	}
   174  	return nil
   175  }
   176  
   177  func (ctr *container) build(anal process.Analyze) error {
   178  	err := ctr.receiveHashMap(anal)
   179  	if err != nil {
   180  		return err
   181  	}
   182  	return ctr.receiveBatch(anal)
   183  }
   184  
   185  func (ctr *container) sendLast(ap *Argument, proc *process.Process, analyze process.Analyze, _ bool, isLast bool, result *vm.CallResult) (bool, error) {
   186  	ctr.handledLast = true
   187  
   188  	if ctr.matched == nil {
   189  		return true, nil
   190  	}
   191  
   192  	if ap.NumCPU > 1 {
   193  		if !ap.IsMerger {
   194  			ap.Channel <- ctr.matched
   195  			return true, nil
   196  		} else {
   197  			cnt := 1
   198  			for v := range ap.Channel {
   199  				ctr.matched.Or(v)
   200  				cnt++
   201  				if cnt == int(ap.NumCPU) {
   202  					close(ap.Channel)
   203  					break
   204  				}
   205  			}
   206  		}
   207  	}
   208  
   209  	count := ctr.batchRowCount - ctr.matched.Count()
   210  	ctr.matched.Negate()
   211  	sels := make([]int32, 0, count)
   212  	itr := ctr.matched.Iterator()
   213  	for itr.HasNext() {
   214  		r := itr.Next()
   215  		sels = append(sels, int32(r))
   216  	}
   217  
   218  	if ctr.rbat != nil {
   219  		proc.PutBatch(ctr.rbat)
   220  		ctr.rbat = nil
   221  	}
   222  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   223  
   224  	for i, rp := range ap.Result {
   225  		if rp.Rel == 0 {
   226  			ctr.rbat.Vecs[i] = proc.GetVector(ap.LeftTypes[rp.Pos])
   227  		} else {
   228  			ctr.rbat.Vecs[i] = proc.GetVector(ap.RightTypes[rp.Pos])
   229  		}
   230  	}
   231  
   232  	for i, rp := range ap.Result {
   233  		if rp.Rel == 0 {
   234  			if err := vector.AppendMultiFixed(ctr.rbat.Vecs[i], 0, true, count, proc.Mp()); err != nil {
   235  				return false, err
   236  			}
   237  		} else {
   238  			for _, sel := range sels {
   239  				idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   240  				if err := ctr.rbat.Vecs[i].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], int64(idx2), proc.Mp()); err != nil {
   241  					return false, err
   242  				}
   243  			}
   244  		}
   245  
   246  	}
   247  	ctr.rbat.AddRowCount(len(sels))
   248  	analyze.Output(ctr.rbat, isLast)
   249  	result.Batch = ctr.rbat
   250  	return false, nil
   251  }
   252  
   253  func (ctr *container) probe(ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   254  	anal.Input(ap.bat, isFirst)
   255  	if ctr.rbat != nil {
   256  		proc.PutBatch(ctr.rbat)
   257  		ctr.rbat = nil
   258  	}
   259  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   260  	for i, rp := range ap.Result {
   261  		if rp.Rel == 0 {
   262  			ctr.rbat.Vecs[i] = proc.GetVector(*ap.bat.Vecs[rp.Pos].GetType())
   263  		} else {
   264  			ctr.rbat.Vecs[i] = proc.GetVector(ap.RightTypes[rp.Pos])
   265  		}
   266  	}
   267  
   268  	if err := ctr.evalJoinCondition(ap.bat, proc); err != nil {
   269  		return err
   270  	}
   271  	if ctr.joinBat1 == nil {
   272  		ctr.joinBat1, ctr.cfs1 = colexec.NewJoinBatch(ap.bat, proc.Mp())
   273  	}
   274  	if ctr.joinBat2 == nil {
   275  		ctr.joinBat2, ctr.cfs2 = colexec.NewJoinBatch(ctr.batches[0], proc.Mp())
   276  	}
   277  	count := ap.bat.RowCount()
   278  	mSels := ctr.mp.Sels()
   279  	itr := ctr.mp.NewIterator()
   280  
   281  	rowCountIncrese := 0
   282  	for i := ap.lastpos; i < count; i += hashmap.UnitLimit {
   283  		if rowCountIncrese >= colexec.DefaultBatchSize {
   284  			ctr.rbat.AddRowCount(rowCountIncrese)
   285  			anal.Output(ctr.rbat, isLast)
   286  			result.Batch = ctr.rbat
   287  			ap.lastpos = i
   288  			return nil
   289  		}
   290  		n := count - i
   291  		if n > hashmap.UnitLimit {
   292  			n = hashmap.UnitLimit
   293  		}
   294  		copy(ctr.inBuckets, hashmap.OneUInt8s)
   295  		vals, zvals := itr.Find(i, n, ctr.vecs, ctr.inBuckets)
   296  		for k := 0; k < n; k++ {
   297  			if ctr.inBuckets[k] == 0 || zvals[k] == 0 || vals[k] == 0 {
   298  				continue
   299  			}
   300  			if ap.HashOnPK {
   301  				idx1, idx2 := int64(vals[k]-1)/colexec.DefaultBatchSize, int64(vals[k]-1)%colexec.DefaultBatchSize
   302  				if ap.Cond != nil {
   303  					if err := colexec.SetJoinBatchValues(ctr.joinBat1, ap.bat, int64(i+k),
   304  						1, ctr.cfs1); err != nil {
   305  						return err
   306  					}
   307  					if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], idx2,
   308  						1, ctr.cfs2); err != nil {
   309  						return err
   310  					}
   311  					vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2})
   312  					if err != nil {
   313  						return err
   314  					}
   315  					if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   316  						continue
   317  					}
   318  					bs := vector.MustFixedCol[bool](vec)
   319  					if bs[0] {
   320  						for j, rp := range ap.Result {
   321  							if rp.Rel == 0 {
   322  								if err := ctr.rbat.Vecs[j].UnionOne(ap.bat.Vecs[rp.Pos], int64(i+k), proc.Mp()); err != nil {
   323  									return err
   324  								}
   325  							} else {
   326  								if err := ctr.rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], idx2, proc.Mp()); err != nil {
   327  									return err
   328  								}
   329  							}
   330  						}
   331  						ctr.matched.Add(vals[k] - 1)
   332  						rowCountIncrese++
   333  					}
   334  				} else {
   335  					for j, rp := range ap.Result {
   336  						if rp.Rel == 0 {
   337  							if err := ctr.rbat.Vecs[j].UnionMulti(ap.bat.Vecs[rp.Pos], int64(i+k), 1, proc.Mp()); err != nil {
   338  								return err
   339  							}
   340  						} else {
   341  							if err := ctr.rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], idx2, proc.Mp()); err != nil {
   342  								return err
   343  							}
   344  						}
   345  					}
   346  					ctr.matched.Add(vals[k] - 1)
   347  					rowCountIncrese++
   348  				}
   349  			} else {
   350  				sels := mSels[vals[k]-1]
   351  				if ap.Cond != nil {
   352  					for _, sel := range sels {
   353  						idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   354  						if err := colexec.SetJoinBatchValues(ctr.joinBat1, ap.bat, int64(i+k),
   355  							1, ctr.cfs1); err != nil {
   356  							return err
   357  						}
   358  						if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], int64(idx2),
   359  							1, ctr.cfs2); err != nil {
   360  							return err
   361  						}
   362  						vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2})
   363  						if err != nil {
   364  							return err
   365  						}
   366  						if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   367  							continue
   368  						}
   369  						bs := vector.MustFixedCol[bool](vec)
   370  						if !bs[0] {
   371  							continue
   372  						}
   373  						for j, rp := range ap.Result {
   374  							if rp.Rel == 0 {
   375  								if err := ctr.rbat.Vecs[j].UnionOne(ap.bat.Vecs[rp.Pos], int64(i+k), proc.Mp()); err != nil {
   376  									return err
   377  								}
   378  							} else {
   379  								if err := ctr.rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], int64(idx2), proc.Mp()); err != nil {
   380  									return err
   381  								}
   382  							}
   383  						}
   384  						ctr.matched.Add(uint64(sel))
   385  						rowCountIncrese++
   386  					}
   387  				} else {
   388  					for j, rp := range ap.Result {
   389  						if rp.Rel == 0 {
   390  							if err := ctr.rbat.Vecs[j].UnionMulti(ap.bat.Vecs[rp.Pos], int64(i+k), len(sels), proc.Mp()); err != nil {
   391  								return err
   392  							}
   393  						} else {
   394  							for _, sel := range sels {
   395  								idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   396  								if err := ctr.rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], int64(idx2), proc.Mp()); err != nil {
   397  									return err
   398  								}
   399  							}
   400  						}
   401  					}
   402  					for _, sel := range sels {
   403  						ctr.matched.Add(uint64(sel))
   404  					}
   405  					rowCountIncrese += len(sels)
   406  				}
   407  			}
   408  
   409  		}
   410  	}
   411  
   412  	ctr.rbat.AddRowCount(rowCountIncrese)
   413  	anal.Output(ctr.rbat, isLast)
   414  	result.Batch = ctr.rbat
   415  	ap.lastpos = 0
   416  	return nil
   417  }
   418  
   419  func (ctr *container) evalJoinCondition(bat *batch.Batch, proc *process.Process) error {
   420  	for i := range ctr.evecs {
   421  		vec, err := ctr.evecs[i].executor.Eval(proc, []*batch.Batch{bat})
   422  		if err != nil {
   423  			return err
   424  		}
   425  		ctr.vecs[i] = vec
   426  		ctr.evecs[i].vec = vec
   427  	}
   428  	return nil
   429  }