github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/join/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package join
    16  
    17  import (
    18  	"bytes"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    21  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    22  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    23  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    24  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    25  	"github.com/matrixorigin/matrixone/pkg/vm"
    26  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    27  )
    28  
    29  const argName = "join"
    30  
    31  func (arg *Argument) String(buf *bytes.Buffer) {
    32  	buf.WriteString(argName)
    33  	buf.WriteString(": inner join ")
    34  }
    35  
    36  func (arg *Argument) Prepare(proc *process.Process) (err error) {
    37  	ap := arg
    38  	ap.ctr = new(container)
    39  	ap.ctr.InitReceiver(proc, false)
    40  	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    41  	ap.ctr.vecs = make([]*vector.Vector, len(ap.Conditions[0]))
    42  	ap.ctr.evecs = make([]evalVector, len(ap.Conditions[0]))
    43  	for i := range ap.ctr.evecs {
    44  		ap.ctr.evecs[i].executor, err = colexec.NewExpressionExecutor(proc, ap.Conditions[0][i])
    45  		if err != nil {
    46  			return err
    47  		}
    48  	}
    49  
    50  	if ap.Cond != nil {
    51  		ap.ctr.expr, err = colexec.NewExpressionExecutor(proc, ap.Cond)
    52  	}
    53  	return err
    54  }
    55  
    56  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    57  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    58  		return vm.CancelResult, err
    59  	}
    60  
    61  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    62  	anal.Start()
    63  	defer anal.Stop()
    64  	ap := arg
    65  	ctr := ap.ctr
    66  	result := vm.NewCallResult()
    67  	for {
    68  		switch ctr.state {
    69  		case Build:
    70  			if err := ctr.build(anal); err != nil {
    71  				return result, err
    72  			}
    73  			if ctr.mp == nil && !arg.IsShuffle {
    74  				// for inner ,right and semi join, if hashmap is empty, we can finish this pipeline
    75  				// shuffle join can't stop early for this moment
    76  				ctr.state = End
    77  			} else {
    78  				ctr.state = Probe
    79  			}
    80  		case Probe:
    81  			if ap.bat == nil {
    82  				bat, _, err := ctr.ReceiveFromSingleReg(0, anal)
    83  				if err != nil {
    84  					return result, err
    85  				}
    86  				if bat == nil {
    87  					ctr.state = End
    88  					continue
    89  				}
    90  				if bat.Last() {
    91  					result.Batch = bat
    92  					return result, nil
    93  				}
    94  				if bat.IsEmpty() {
    95  					proc.PutBatch(bat)
    96  					continue
    97  				}
    98  				if ctr.mp == nil {
    99  					proc.PutBatch(bat)
   100  					continue
   101  				}
   102  				ap.bat = bat
   103  				ap.lastrow = 0
   104  			}
   105  
   106  			startrow := ap.lastrow
   107  			if err := ctr.probe(ap, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result); err != nil {
   108  				return result, err
   109  			}
   110  			if ap.lastrow == 0 {
   111  				proc.PutBatch(ap.bat)
   112  				ap.bat = nil
   113  			} else if ap.lastrow == startrow {
   114  				return result, moerr.NewInternalErrorNoCtx("inner join hanging")
   115  			}
   116  			return result, nil
   117  
   118  		default:
   119  			result.Batch = nil
   120  			result.Status = vm.ExecStop
   121  			return result, nil
   122  		}
   123  	}
   124  }
   125  
   126  func (ctr *container) receiveHashMap(anal process.Analyze) error {
   127  	bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   128  	if err != nil {
   129  		return err
   130  	}
   131  	if bat != nil && bat.AuxData != nil {
   132  		ctr.mp = bat.DupJmAuxData()
   133  		ctr.maxAllocSize = max(ctr.maxAllocSize, ctr.mp.Size())
   134  	}
   135  	return nil
   136  }
   137  
   138  func (ctr *container) receiveBatch(anal process.Analyze) error {
   139  	for {
   140  		bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   141  		if err != nil {
   142  			return err
   143  		}
   144  		if bat != nil {
   145  			ctr.batchRowCount += bat.RowCount()
   146  			ctr.batches = append(ctr.batches, bat)
   147  		} else {
   148  			break
   149  		}
   150  	}
   151  	for i := 0; i < len(ctr.batches)-1; i++ {
   152  		if ctr.batches[i].RowCount() != colexec.DefaultBatchSize {
   153  			panic("wrong batch received for hash build!")
   154  		}
   155  	}
   156  	return nil
   157  }
   158  
   159  func (ctr *container) build(anal process.Analyze) error {
   160  	err := ctr.receiveHashMap(anal)
   161  	if err != nil {
   162  		return err
   163  	}
   164  	return ctr.receiveBatch(anal)
   165  }
   166  
   167  func (ctr *container) probe(ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   168  
   169  	anal.Input(ap.bat, isFirst)
   170  	if ctr.rbat != nil {
   171  		proc.PutBatch(ctr.rbat)
   172  		ctr.rbat = nil
   173  	}
   174  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   175  	for i, rp := range ap.Result {
   176  		if rp.Rel == 0 {
   177  			ctr.rbat.Vecs[i] = proc.GetVector(*ap.bat.Vecs[rp.Pos].GetType())
   178  			// for inner join, if left batch is sorted , then output batch is sorted
   179  			ctr.rbat.Vecs[i].SetSorted(ap.bat.Vecs[rp.Pos].GetSorted())
   180  		} else {
   181  			ctr.rbat.Vecs[i] = proc.GetVector(*ctr.batches[0].Vecs[rp.Pos].GetType())
   182  		}
   183  	}
   184  
   185  	if err := ctr.evalJoinCondition(ap.bat, proc); err != nil {
   186  		return err
   187  	}
   188  	if ctr.joinBat1 == nil {
   189  		ctr.joinBat1, ctr.cfs1 = colexec.NewJoinBatch(ap.bat, proc.Mp())
   190  	}
   191  	if ctr.joinBat2 == nil && ctr.batchRowCount > 0 {
   192  		ctr.joinBat2, ctr.cfs2 = colexec.NewJoinBatch(ctr.batches[0], proc.Mp())
   193  	}
   194  
   195  	mSels := ctr.mp.Sels()
   196  	count := ap.bat.RowCount()
   197  	itr := ctr.mp.NewIterator()
   198  	rowCount := 0
   199  	for i := ap.lastrow; i < count; i += hashmap.UnitLimit {
   200  		if rowCount >= colexec.DefaultBatchSize {
   201  			ctr.rbat.AddRowCount(rowCount)
   202  			anal.Output(ctr.rbat, isLast)
   203  			result.Batch = ctr.rbat
   204  			ap.lastrow = i
   205  			return nil
   206  		}
   207  		n := count - i
   208  		if n > hashmap.UnitLimit {
   209  			n = hashmap.UnitLimit
   210  		}
   211  		copy(ctr.inBuckets, hashmap.OneUInt8s)
   212  
   213  		vals, zvals := itr.Find(i, n, ctr.vecs, ctr.inBuckets)
   214  		for k := 0; k < n; k++ {
   215  			if ctr.inBuckets[k] == 0 || zvals[k] == 0 || vals[k] == 0 {
   216  				continue
   217  			}
   218  			idx := vals[k] - 1
   219  
   220  			if ap.Cond == nil {
   221  				if ap.HashOnPK {
   222  					for j, rp := range ap.Result {
   223  						if rp.Rel == 0 {
   224  							if err := ctr.rbat.Vecs[j].UnionOne(ap.bat.Vecs[rp.Pos], int64(i+k), proc.Mp()); err != nil {
   225  								return err
   226  							}
   227  						} else {
   228  							idx1, idx2 := idx/colexec.DefaultBatchSize, idx%colexec.DefaultBatchSize
   229  							if err := ctr.rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], int64(idx2), proc.Mp()); err != nil {
   230  								return err
   231  							}
   232  						}
   233  					}
   234  					rowCount++
   235  				} else {
   236  					sels := mSels[idx]
   237  					for j, rp := range ap.Result {
   238  						if rp.Rel == 0 {
   239  							if err := ctr.rbat.Vecs[j].UnionMulti(ap.bat.Vecs[rp.Pos], int64(i+k), len(sels), proc.Mp()); err != nil {
   240  								return err
   241  							}
   242  						} else {
   243  							for _, sel := range sels {
   244  								idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   245  								if err := ctr.rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], int64(idx2), proc.Mp()); err != nil {
   246  									return err
   247  								}
   248  							}
   249  						}
   250  					}
   251  					rowCount += len(sels)
   252  				}
   253  			} else {
   254  				if ap.HashOnPK {
   255  					if err := ctr.evalApCondForOneSel(ap.bat, ctr.rbat, ap, proc, int64(i+k), int64(idx)); err != nil {
   256  						return err
   257  					}
   258  					rowCount++
   259  				} else {
   260  					sels := mSels[idx]
   261  					for _, sel := range sels {
   262  						if err := ctr.evalApCondForOneSel(ap.bat, ctr.rbat, ap, proc, int64(i+k), int64(sel)); err != nil {
   263  							return err
   264  						}
   265  					}
   266  					rowCount += len(sels)
   267  				}
   268  			}
   269  		}
   270  	}
   271  
   272  	ctr.rbat.AddRowCount(rowCount)
   273  	anal.Output(ctr.rbat, isLast)
   274  	result.Batch = ctr.rbat
   275  	ap.lastrow = 0
   276  	return nil
   277  }
   278  
   279  func (ctr *container) evalApCondForOneSel(bat, rbat *batch.Batch, ap *Argument, proc *process.Process, row, sel int64) error {
   280  	if err := colexec.SetJoinBatchValues(ctr.joinBat1, bat, row,
   281  		1, ctr.cfs1); err != nil {
   282  		return err
   283  	}
   284  	idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   285  	if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], idx2,
   286  		1, ctr.cfs2); err != nil {
   287  		return err
   288  	}
   289  	vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2})
   290  	if err != nil {
   291  		rbat.Clean(proc.Mp())
   292  		return err
   293  	}
   294  	if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   295  		return nil
   296  	}
   297  	bs := vector.MustFixedCol[bool](vec)
   298  	if !bs[0] {
   299  		return nil
   300  	}
   301  	for j, rp := range ap.Result {
   302  		if rp.Rel == 0 {
   303  			if err := rbat.Vecs[j].UnionOne(bat.Vecs[rp.Pos], row, proc.Mp()); err != nil {
   304  				rbat.Clean(proc.Mp())
   305  				return err
   306  			}
   307  		} else {
   308  			if err := rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], idx2, proc.Mp()); err != nil {
   309  				rbat.Clean(proc.Mp())
   310  				return err
   311  			}
   312  		}
   313  	}
   314  	return nil
   315  }
   316  
   317  func (ctr *container) evalJoinCondition(bat *batch.Batch, proc *process.Process) error {
   318  	for i := range ctr.evecs {
   319  		vec, err := ctr.evecs[i].executor.Eval(proc, []*batch.Batch{bat})
   320  		if err != nil {
   321  			return err
   322  		}
   323  		ctr.vecs[i] = vec
   324  		ctr.evecs[i].vec = vec
   325  	}
   326  	return nil
   327  }