github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/semi/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package semi
    16  
    17  import (
    18  	"bytes"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    21  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    22  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    23  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    24  	"github.com/matrixorigin/matrixone/pkg/vm"
    25  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    26  )
    27  
    28  const argName = "semi"
    29  
    30  func (arg *Argument) String(buf *bytes.Buffer) {
    31  	buf.WriteString(argName)
    32  	buf.WriteString(": semi join ")
    33  }
    34  
    35  func (arg *Argument) Prepare(proc *process.Process) (err error) {
    36  	ap := arg
    37  	ap.ctr = new(container)
    38  	ap.ctr.InitReceiver(proc, false)
    39  	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    40  	ap.ctr.vecs = make([]*vector.Vector, len(ap.Conditions[0]))
    41  
    42  	ap.ctr.evecs = make([]evalVector, len(ap.Conditions[0]))
    43  	for i := range ap.ctr.evecs {
    44  		ap.ctr.evecs[i].executor, err = colexec.NewExpressionExecutor(proc, ap.Conditions[0][i])
    45  		if err != nil {
    46  			return err
    47  		}
    48  	}
    49  
    50  	if ap.Cond != nil {
    51  		ap.ctr.expr, err = colexec.NewExpressionExecutor(proc, ap.Cond)
    52  	}
    53  	return err
    54  }
    55  
    56  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    57  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    58  		return vm.CancelResult, err
    59  	}
    60  
    61  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    62  	anal.Start()
    63  	defer anal.Stop()
    64  	ap := arg
    65  	ctr := ap.ctr
    66  	result := vm.NewCallResult()
    67  	for {
    68  		switch ctr.state {
    69  		case Build:
    70  			if err := ctr.build(anal); err != nil {
    71  				return result, err
    72  			}
    73  			if ctr.mp == nil && !arg.IsShuffle {
    74  				// for inner ,right and semi join, if hashmap is empty, we can finish this pipeline
    75  				// shuffle join can't stop early for this moment
    76  				ctr.state = End
    77  			} else {
    78  				ctr.state = Probe
    79  			}
    80  			if ctr.mp != nil && ctr.mp.PushedRuntimeFilterIn() && ap.Cond == nil {
    81  				ctr.skipProbe = true
    82  			}
    83  
    84  		case Probe:
    85  			bat, _, err := ctr.ReceiveFromSingleReg(0, anal)
    86  			if err != nil {
    87  				return result, err
    88  			}
    89  
    90  			if bat == nil {
    91  				ctr.state = End
    92  				continue
    93  			}
    94  			if bat.IsEmpty() {
    95  				proc.PutBatch(bat)
    96  				continue
    97  			}
    98  			if ctr.skipProbe {
    99  				vecused := make([]bool, len(bat.Vecs))
   100  				newvecs := make([]*vector.Vector, len(ap.Result))
   101  				for i, pos := range ap.Result {
   102  					vecused[pos] = true
   103  					newvecs[i] = bat.Vecs[pos]
   104  				}
   105  				for i := range bat.Vecs {
   106  					if !vecused[i] {
   107  						bat.Vecs[i].Free(proc.Mp())
   108  					}
   109  				}
   110  				bat.Vecs = newvecs
   111  				result.Batch = bat
   112  				anal.Output(bat, arg.GetIsLast())
   113  				return result, nil
   114  			}
   115  			if ctr.mp == nil {
   116  				proc.PutBatch(bat)
   117  				continue
   118  			}
   119  			if err := ctr.probe(bat, ap, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result); err != nil {
   120  				bat.Clean(proc.Mp())
   121  				return result, err
   122  			}
   123  			proc.PutBatch(bat)
   124  			return result, nil
   125  
   126  		default:
   127  			result.Batch = nil
   128  			result.Status = vm.ExecStop
   129  			return result, nil
   130  		}
   131  	}
   132  }
   133  
   134  func (ctr *container) receiveHashMap(anal process.Analyze) error {
   135  	bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   136  	if err != nil {
   137  		return err
   138  	}
   139  	if bat != nil && bat.AuxData != nil {
   140  		ctr.mp = bat.DupJmAuxData()
   141  		ctr.maxAllocSize = max(ctr.maxAllocSize, ctr.mp.Size())
   142  	}
   143  	return nil
   144  }
   145  
   146  func (ctr *container) receiveBatch(anal process.Analyze) error {
   147  	for {
   148  		bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   149  		if err != nil {
   150  			return err
   151  		}
   152  		if bat != nil {
   153  			ctr.batchRowCount += bat.RowCount()
   154  			ctr.batches = append(ctr.batches, bat)
   155  		} else {
   156  			break
   157  		}
   158  	}
   159  	for i := 0; i < len(ctr.batches)-1; i++ {
   160  		if ctr.batches[i].RowCount() != colexec.DefaultBatchSize {
   161  			panic("wrong batch received for hash build!")
   162  		}
   163  	}
   164  	return nil
   165  }
   166  
   167  func (ctr *container) build(anal process.Analyze) error {
   168  	err := ctr.receiveHashMap(anal)
   169  	if err != nil {
   170  		return err
   171  	}
   172  	return ctr.receiveBatch(anal)
   173  }
   174  
   175  func (ctr *container) probe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   176  	anal.Input(bat, isFirst)
   177  	if ctr.rbat != nil {
   178  		proc.PutBatch(ctr.rbat)
   179  		ctr.rbat = nil
   180  	}
   181  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   182  	for i, pos := range ap.Result {
   183  		ctr.rbat.Vecs[i] = proc.GetVector(*bat.Vecs[pos].GetType())
   184  		// for semi join, if left batch is sorted , then output batch is sorted
   185  		ctr.rbat.Vecs[i].SetSorted(bat.Vecs[pos].GetSorted())
   186  	}
   187  	if err := ctr.evalJoinCondition(bat, proc); err != nil {
   188  		return err
   189  	}
   190  	if ctr.joinBat1 == nil {
   191  		ctr.joinBat1, ctr.cfs1 = colexec.NewJoinBatch(bat, proc.Mp())
   192  	}
   193  	if ctr.joinBat2 == nil && ctr.batchRowCount > 0 {
   194  		ctr.joinBat2, ctr.cfs2 = colexec.NewJoinBatch(ctr.batches[0], proc.Mp())
   195  	}
   196  	count := bat.RowCount()
   197  	mSels := ctr.mp.Sels()
   198  	itr := ctr.mp.NewIterator()
   199  
   200  	rowCountIncrease := 0
   201  	eligible := make([]int32, 0, hashmap.UnitLimit)
   202  	for i := 0; i < count; i += hashmap.UnitLimit {
   203  		n := count - i
   204  		if n > hashmap.UnitLimit {
   205  			n = hashmap.UnitLimit
   206  		}
   207  		copy(ctr.inBuckets, hashmap.OneUInt8s)
   208  		vals, zvals := itr.Find(i, n, ctr.vecs, ctr.inBuckets)
   209  		for k := 0; k < n; k++ {
   210  			if ctr.inBuckets[k] == 0 || zvals[k] == 0 || vals[k] == 0 {
   211  				continue
   212  			}
   213  			if ap.Cond != nil {
   214  				matched := false // mark if any tuple satisfies the condition
   215  				if ap.HashOnPK {
   216  					idx1, idx2 := int64(vals[k]-1)/colexec.DefaultBatchSize, int64(vals[k]-1)%colexec.DefaultBatchSize
   217  					if err := colexec.SetJoinBatchValues(ctr.joinBat1, bat, int64(i+k),
   218  						1, ctr.cfs1); err != nil {
   219  						return err
   220  					}
   221  					if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], idx2,
   222  						1, ctr.cfs2); err != nil {
   223  						return err
   224  					}
   225  					vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2})
   226  					if err != nil {
   227  						return err
   228  					}
   229  					if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   230  						continue
   231  					}
   232  					bs := vector.MustFixedCol[bool](vec)
   233  					if bs[0] {
   234  						matched = true
   235  					}
   236  				} else {
   237  					sels := mSels[vals[k]-1]
   238  					for _, sel := range sels {
   239  						idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   240  						if err := colexec.SetJoinBatchValues(ctr.joinBat1, bat, int64(i+k),
   241  							1, ctr.cfs1); err != nil {
   242  							return err
   243  						}
   244  						if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], int64(idx2),
   245  							1, ctr.cfs2); err != nil {
   246  							return err
   247  						}
   248  						vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2})
   249  						if err != nil {
   250  							return err
   251  						}
   252  						if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   253  							continue
   254  						}
   255  						bs := vector.MustFixedCol[bool](vec)
   256  						if bs[0] {
   257  							matched = true
   258  							break
   259  						}
   260  					}
   261  
   262  				}
   263  				if !matched {
   264  					continue
   265  				}
   266  			}
   267  			eligible = append(eligible, int32(i+k))
   268  			rowCountIncrease++
   269  		}
   270  		for j, pos := range ap.Result {
   271  			if err := ctr.rbat.Vecs[j].Union(bat.Vecs[pos], eligible, proc.Mp()); err != nil {
   272  				return err
   273  			}
   274  		}
   275  		eligible = eligible[:0]
   276  	}
   277  
   278  	ctr.rbat.AddRowCount(rowCountIncrease)
   279  	anal.Output(ctr.rbat, isLast)
   280  	result.Batch = ctr.rbat
   281  	return nil
   282  }
   283  
   284  func (ctr *container) evalJoinCondition(bat *batch.Batch, proc *process.Process) error {
   285  	for i := range ctr.evecs {
   286  		vec, err := ctr.evecs[i].executor.Eval(proc, []*batch.Batch{bat})
   287  		if err != nil {
   288  			return err
   289  		}
   290  		ctr.vecs[i] = vec
   291  		ctr.evecs[i].vec = vec
   292  	}
   293  	return nil
   294  }