github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/rightanti/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rightanti
    16  
    17  import (
    18  	"bytes"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/common/bitmap"
    21  
    22  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    23  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    24  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    25  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    26  	"github.com/matrixorigin/matrixone/pkg/vm"
    27  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    28  )
    29  
    30  const argName = "right_anti"
    31  
    32  func (arg *Argument) String(buf *bytes.Buffer) {
    33  	buf.WriteString(argName)
    34  	buf.WriteString(": right anti join ")
    35  }
    36  
    37  func (arg *Argument) Prepare(proc *process.Process) (err error) {
    38  	ap := arg
    39  	ap.ctr = new(container)
    40  	ap.ctr.InitReceiver(proc, false)
    41  	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    42  	ap.ctr.vecs = make([]*vector.Vector, len(ap.Conditions[0]))
    43  	ap.ctr.evecs = make([]evalVector, len(ap.Conditions[0]))
    44  	for i := range ap.ctr.evecs {
    45  		ap.ctr.evecs[i].executor, err = colexec.NewExpressionExecutor(proc, ap.Conditions[0][i])
    46  		if err != nil {
    47  			return err
    48  		}
    49  	}
    50  	if ap.Cond != nil {
    51  		ap.ctr.expr, err = colexec.NewExpressionExecutor(proc, ap.Cond)
    52  	}
    53  
    54  	ap.ctr.tmpBatches = make([]*batch.Batch, 2)
    55  	return err
    56  }
    57  
    58  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    59  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    60  		return vm.CancelResult, err
    61  	}
    62  
    63  	analyze := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    64  	analyze.Start()
    65  	defer analyze.Stop()
    66  	ap := arg
    67  	ctr := ap.ctr
    68  	result := vm.NewCallResult()
    69  	for {
    70  		switch ctr.state {
    71  		case Build:
    72  			if err := ctr.build(analyze); err != nil {
    73  				return result, err
    74  			}
    75  			// for inner ,right and semi join, if hashmap is empty, we can finish this pipeline
    76  			// shuffle join can't stop early for this moment
    77  			if ctr.mp == nil && !arg.IsShuffle {
    78  				ctr.state = End
    79  			} else {
    80  				ctr.state = Probe
    81  			}
    82  
    83  		case Probe:
    84  			bat, _, err := ctr.ReceiveFromSingleReg(0, analyze)
    85  			if err != nil {
    86  				return result, err
    87  			}
    88  
    89  			if bat == nil {
    90  				ctr.state = SendLast
    91  				ap.rbat = nil
    92  				continue
    93  			}
    94  			if bat.IsEmpty() {
    95  				proc.PutBatch(bat)
    96  				continue
    97  			}
    98  
    99  			if ctr.batchRowCount == 0 {
   100  				proc.PutBatch(bat)
   101  				continue
   102  			}
   103  
   104  			if err := ctr.probe(bat, ap, proc, analyze, arg.GetIsFirst(), arg.GetIsLast()); err != nil {
   105  				return result, err
   106  			}
   107  
   108  			continue
   109  
   110  		case SendLast:
   111  			if ap.rbat == nil {
   112  				ap.lastpos = 0
   113  				setNil, err := ctr.sendLast(ap, proc, analyze, arg.GetIsFirst(), arg.GetIsLast())
   114  				if err != nil {
   115  					return result, err
   116  				}
   117  				if setNil {
   118  					ctr.state = End
   119  				}
   120  				continue
   121  			} else {
   122  				if ap.lastpos >= len(ap.rbat) {
   123  					ctr.state = End
   124  					continue
   125  				}
   126  				result.Batch = ap.rbat[ap.lastpos]
   127  				ap.lastpos++
   128  				return result, nil
   129  			}
   130  
   131  		default:
   132  			result.Batch = nil
   133  			result.Status = vm.ExecStop
   134  			return result, nil
   135  		}
   136  	}
   137  }
   138  
   139  func (ctr *container) receiveHashMap(anal process.Analyze) error {
   140  	bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   141  	if err != nil {
   142  		return err
   143  	}
   144  	if bat != nil && bat.AuxData != nil {
   145  		ctr.mp = bat.DupJmAuxData()
   146  		ctr.maxAllocSize = max(ctr.maxAllocSize, ctr.mp.Size())
   147  	}
   148  	return nil
   149  }
   150  
   151  func (ctr *container) receiveBatch(anal process.Analyze) error {
   152  	for {
   153  		bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   154  		if err != nil {
   155  			return err
   156  		}
   157  		if bat != nil {
   158  			ctr.batchRowCount += bat.RowCount()
   159  			ctr.batches = append(ctr.batches, bat)
   160  		} else {
   161  			break
   162  		}
   163  	}
   164  	for i := 0; i < len(ctr.batches)-1; i++ {
   165  		if ctr.batches[i].RowCount() != colexec.DefaultBatchSize {
   166  			panic("wrong batch received for hash build!")
   167  		}
   168  	}
   169  	if ctr.batchRowCount > 0 {
   170  		ctr.matched = &bitmap.Bitmap{}
   171  		ctr.matched.InitWithSize(int64(ctr.batchRowCount))
   172  	}
   173  	return nil
   174  }
   175  
   176  func (ctr *container) build(anal process.Analyze) error {
   177  	err := ctr.receiveHashMap(anal)
   178  	if err != nil {
   179  		return err
   180  	}
   181  	return ctr.receiveBatch(anal)
   182  }
   183  
   184  func (ctr *container) sendLast(ap *Argument, proc *process.Process, analyze process.Analyze, _ bool, isLast bool) (bool, error) {
   185  	ctr.handledLast = true
   186  
   187  	if ctr.matched == nil {
   188  		return true, nil
   189  	}
   190  
   191  	if ap.NumCPU > 1 {
   192  		if !ap.IsMerger {
   193  			ap.Channel <- ctr.matched
   194  			return true, nil
   195  		}
   196  
   197  		cnt := 1
   198  		for v := range ap.Channel {
   199  			ctr.matched.Or(v)
   200  			cnt++
   201  			if cnt == int(ap.NumCPU) {
   202  				close(ap.Channel)
   203  				break
   204  			}
   205  		}
   206  	}
   207  
   208  	count := ctr.batchRowCount - ctr.matched.Count()
   209  	ctr.matched.Negate()
   210  	sels := make([]int32, 0, count)
   211  	itr := ctr.matched.Iterator()
   212  	for itr.HasNext() {
   213  		r := itr.Next()
   214  		sels = append(sels, int32(r))
   215  	}
   216  
   217  	if len(sels) <= colexec.DefaultBatchSize {
   218  		if ctr.rbat != nil {
   219  			proc.PutBatch(ctr.rbat)
   220  			ctr.rbat = nil
   221  		}
   222  		ctr.rbat = batch.NewWithSize(len(ap.Result))
   223  
   224  		for i, pos := range ap.Result {
   225  			ctr.rbat.Vecs[i] = proc.GetVector(ap.RightTypes[pos])
   226  		}
   227  		for j, pos := range ap.Result {
   228  			for _, sel := range sels {
   229  				idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   230  				if err := ctr.rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[pos], int64(idx2), proc.Mp()); err != nil {
   231  					return false, err
   232  				}
   233  			}
   234  		}
   235  		ctr.rbat.AddRowCount(len(sels))
   236  		analyze.Output(ctr.rbat, isLast)
   237  		ap.rbat = []*batch.Batch{ctr.rbat}
   238  		return false, nil
   239  	} else {
   240  		n := (len(sels)-1)/colexec.DefaultBatchSize + 1
   241  		ap.rbat = make([]*batch.Batch, n)
   242  		for k := range ap.rbat {
   243  			ap.rbat[k] = batch.NewWithSize(len(ap.Result))
   244  			for i, pos := range ap.Result {
   245  				ap.rbat[k].Vecs[i] = proc.GetVector(ap.RightTypes[pos])
   246  			}
   247  			var newsels []int32
   248  			if (k+1)*colexec.DefaultBatchSize <= len(sels) {
   249  				newsels = sels[k*colexec.DefaultBatchSize : (k+1)*colexec.DefaultBatchSize]
   250  			} else {
   251  				newsels = sels[k*colexec.DefaultBatchSize:]
   252  			}
   253  			for j, pos := range ap.Result {
   254  				for _, sel := range newsels {
   255  					idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   256  					if err := ap.rbat[k].Vecs[j].UnionOne(ctr.batches[idx1].Vecs[pos], int64(idx2), proc.Mp()); err != nil {
   257  						return false, err
   258  					}
   259  				}
   260  			}
   261  			ap.rbat[k].SetRowCount(len(newsels))
   262  			analyze.Output(ap.rbat[k], isLast)
   263  		}
   264  		return false, nil
   265  	}
   266  
   267  }
   268  
   269  func (ctr *container) probe(bat *batch.Batch, ap *Argument, proc *process.Process, analyze process.Analyze, isFirst bool, _ bool) error {
   270  	defer proc.PutBatch(bat)
   271  	analyze.Input(bat, isFirst)
   272  
   273  	if err := ctr.evalJoinCondition(bat, proc); err != nil {
   274  		return err
   275  	}
   276  	if ctr.joinBat1 == nil {
   277  		ctr.joinBat1, ctr.cfs1 = colexec.NewJoinBatch(bat, proc.Mp())
   278  	}
   279  	if ctr.joinBat2 == nil {
   280  		ctr.joinBat2, ctr.cfs2 = colexec.NewJoinBatch(ctr.batches[0], proc.Mp())
   281  	}
   282  	count := bat.RowCount()
   283  	mSels := ctr.mp.Sels()
   284  	itr := ctr.mp.NewIterator()
   285  	for i := 0; i < count; i += hashmap.UnitLimit {
   286  		n := count - i
   287  		if n > hashmap.UnitLimit {
   288  			n = hashmap.UnitLimit
   289  		}
   290  		copy(ctr.inBuckets, hashmap.OneUInt8s)
   291  		vals, zvals := itr.Find(i, n, ctr.vecs, ctr.inBuckets)
   292  		for k := 0; k < n; k++ {
   293  			if ctr.inBuckets[k] == 0 || zvals[k] == 0 || vals[k] == 0 {
   294  				continue
   295  			}
   296  			if ap.HashOnPK {
   297  				idx1, idx2 := int64(vals[k]-1)/colexec.DefaultBatchSize, int64(vals[k]-1)%colexec.DefaultBatchSize
   298  				if ctr.matched.Contains(vals[k] - 1) {
   299  					continue
   300  				}
   301  				if ap.Cond != nil {
   302  					if err := colexec.SetJoinBatchValues(ctr.joinBat1, bat, int64(i+k),
   303  						1, ctr.cfs1); err != nil {
   304  						return err
   305  					}
   306  					if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], idx2,
   307  						1, ctr.cfs2); err != nil {
   308  						return err
   309  					}
   310  					ctr.tmpBatches[0] = ctr.joinBat1
   311  					ctr.tmpBatches[1] = ctr.joinBat2
   312  					vec, err := ctr.expr.Eval(proc, ctr.tmpBatches)
   313  					if err != nil {
   314  						return err
   315  					}
   316  					if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   317  						continue
   318  					} else {
   319  						vcol := vector.MustFixedCol[bool](vec)
   320  						if !vcol[0] {
   321  							continue
   322  						}
   323  					}
   324  				}
   325  				ctr.matched.Add(vals[k] - 1)
   326  			} else {
   327  				sels := mSels[vals[k]-1]
   328  				for _, sel := range sels {
   329  					if ctr.matched.Contains(uint64(sel)) {
   330  						continue
   331  					}
   332  					idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   333  					if ap.Cond != nil {
   334  						if err := colexec.SetJoinBatchValues(ctr.joinBat1, bat, int64(i+k),
   335  							1, ctr.cfs1); err != nil {
   336  							return err
   337  						}
   338  						if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], int64(idx2),
   339  							1, ctr.cfs2); err != nil {
   340  							return err
   341  						}
   342  						ctr.tmpBatches[0] = ctr.joinBat1
   343  						ctr.tmpBatches[1] = ctr.joinBat2
   344  						vec, err := ctr.expr.Eval(proc, ctr.tmpBatches)
   345  						if err != nil {
   346  							return err
   347  						}
   348  						if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   349  							continue
   350  						} else {
   351  							vcol := vector.MustFixedCol[bool](vec)
   352  							if !vcol[0] {
   353  								continue
   354  							}
   355  						}
   356  					}
   357  					ctr.matched.Add(uint64(sel))
   358  				}
   359  			}
   360  
   361  		}
   362  	}
   363  
   364  	return nil
   365  }
   366  
   367  func (ctr *container) evalJoinCondition(bat *batch.Batch, proc *process.Process) error {
   368  	for i := range ctr.evecs {
   369  		vec, err := ctr.evecs[i].executor.Eval(proc, []*batch.Batch{bat})
   370  		if err != nil {
   371  			return err
   372  		}
   373  		ctr.vecs[i] = vec
   374  		ctr.evecs[i].vec = vec
   375  	}
   376  	return nil
   377  }