github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/left/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package left
    16  
    17  import (
    18  	"bytes"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    21  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    22  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    23  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    24  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    25  	"github.com/matrixorigin/matrixone/pkg/vm"
    26  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    27  )
    28  
    29  const argName = "left"
    30  
    31  func (arg *Argument) String(buf *bytes.Buffer) {
    32  	buf.WriteString(argName)
    33  	buf.WriteString(": left join ")
    34  }
    35  
    36  func (arg *Argument) Prepare(proc *process.Process) (err error) {
    37  	ap := arg
    38  	ap.ctr = new(container)
    39  	ap.ctr.InitReceiver(proc, false)
    40  	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    41  	ap.ctr.vecs = make([]*vector.Vector, len(ap.Conditions[0]))
    42  
    43  	ap.ctr.evecs = make([]evalVector, len(ap.Conditions[0]))
    44  	for i := range ap.ctr.evecs {
    45  		ap.ctr.evecs[i].executor, err = colexec.NewExpressionExecutor(proc, ap.Conditions[0][i])
    46  	}
    47  
    48  	if ap.Cond != nil {
    49  		ap.ctr.expr, err = colexec.NewExpressionExecutor(proc, ap.Cond)
    50  	}
    51  	return err
    52  }
    53  
    54  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    55  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    56  		return vm.CancelResult, err
    57  	}
    58  
    59  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    60  	anal.Start()
    61  	defer anal.Stop()
    62  	ap := arg
    63  	ctr := ap.ctr
    64  	result := vm.NewCallResult()
    65  	for {
    66  		switch ctr.state {
    67  		case Build:
    68  			if err := ctr.build(anal); err != nil {
    69  				return result, err
    70  			}
    71  			ctr.state = Probe
    72  
    73  		case Probe:
    74  			if ap.bat == nil {
    75  				bat, _, err := ctr.ReceiveFromSingleReg(0, anal)
    76  				if err != nil {
    77  					return result, err
    78  				}
    79  
    80  				if bat == nil {
    81  					ctr.state = End
    82  					continue
    83  				}
    84  				if bat.IsEmpty() {
    85  					proc.PutBatch(bat)
    86  					continue
    87  				}
    88  				ap.bat = bat
    89  				ap.lastrow = 0
    90  			}
    91  
    92  			startrow := ap.lastrow
    93  			if ctr.mp == nil {
    94  				if err := ctr.emptyProbe(ap, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result); err != nil {
    95  					return result, err
    96  				}
    97  			} else {
    98  				if err := ctr.probe(ap, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result); err != nil {
    99  					return result, err
   100  				}
   101  			}
   102  			if ap.lastrow == 0 {
   103  				proc.PutBatch(ap.bat)
   104  				ap.bat = nil
   105  			} else if ap.lastrow == startrow {
   106  				return result, moerr.NewInternalErrorNoCtx("left join hanging")
   107  			}
   108  			return result, nil
   109  
   110  		default:
   111  			result.Batch = nil
   112  			result.Status = vm.ExecStop
   113  			return result, nil
   114  		}
   115  	}
   116  }
   117  
   118  func (ctr *container) receiveHashMap(anal process.Analyze) error {
   119  	bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   120  	if err != nil {
   121  		return err
   122  	}
   123  	if bat != nil && bat.AuxData != nil {
   124  		ctr.mp = bat.DupJmAuxData()
   125  		ctr.maxAllocSize = max(ctr.maxAllocSize, ctr.mp.Size())
   126  	}
   127  	return nil
   128  }
   129  
   130  func (ctr *container) receiveBatch(anal process.Analyze) error {
   131  	for {
   132  		bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   133  		if err != nil {
   134  			return err
   135  		}
   136  		if bat != nil {
   137  			ctr.batchRowCount += bat.RowCount()
   138  			ctr.batches = append(ctr.batches, bat)
   139  		} else {
   140  			break
   141  		}
   142  	}
   143  	for i := 0; i < len(ctr.batches)-1; i++ {
   144  		if ctr.batches[i].RowCount() != colexec.DefaultBatchSize {
   145  			panic("wrong batch received for hash build!")
   146  		}
   147  	}
   148  	return nil
   149  }
   150  
   151  func (ctr *container) build(anal process.Analyze) error {
   152  	err := ctr.receiveHashMap(anal)
   153  	if err != nil {
   154  		return err
   155  	}
   156  	return ctr.receiveBatch(anal)
   157  }
   158  
   159  func (ctr *container) emptyProbe(ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   160  	anal.Input(ap.bat, isFirst)
   161  	if ctr.rbat != nil {
   162  		proc.PutBatch(ctr.rbat)
   163  		ctr.rbat = nil
   164  	}
   165  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   166  	//count := bat.RowCount()
   167  	for i, rp := range ap.Result {
   168  		if rp.Rel == 0 {
   169  			// rbat.Vecs[i] = bat.Vecs[rp.Pos]
   170  			// bat.Vecs[rp.Pos] = nil
   171  			typ := *ap.bat.Vecs[rp.Pos].GetType()
   172  			ctr.rbat.Vecs[i] = proc.GetVector(typ)
   173  			if err := vector.GetUnionAllFunction(typ, proc.Mp())(ctr.rbat.Vecs[i], ap.bat.Vecs[rp.Pos]); err != nil {
   174  				return err
   175  			}
   176  			// for left join, if left batch is sorted , then output batch is sorted
   177  			ctr.rbat.Vecs[i].SetSorted(ap.bat.Vecs[rp.Pos].GetSorted())
   178  		} else {
   179  			ctr.rbat.Vecs[i] = vector.NewConstNull(ap.Typs[rp.Pos], ap.bat.RowCount(), proc.Mp())
   180  		}
   181  	}
   182  	ctr.rbat.AddRowCount(ap.bat.RowCount())
   183  	anal.Output(ctr.rbat, isLast)
   184  	result.Batch = ctr.rbat
   185  	ap.lastrow = 0
   186  	return nil
   187  }
   188  
   189  func (ctr *container) probe(ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   190  	anal.Input(ap.bat, isFirst)
   191  	if ctr.rbat != nil {
   192  		proc.PutBatch(ctr.rbat)
   193  		ctr.rbat = nil
   194  	}
   195  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   196  	for i, rp := range ap.Result {
   197  		if rp.Rel == 0 {
   198  			ctr.rbat.Vecs[i] = proc.GetVector(*ap.bat.Vecs[rp.Pos].GetType())
   199  			// for left join, if left batch is sorted , then output batch is sorted
   200  			ctr.rbat.Vecs[i].SetSorted(ap.bat.Vecs[rp.Pos].GetSorted())
   201  		} else {
   202  			ctr.rbat.Vecs[i] = proc.GetVector(ap.Typs[rp.Pos])
   203  		}
   204  	}
   205  
   206  	if err := ctr.evalJoinCondition(ap.bat, proc); err != nil {
   207  		return err
   208  	}
   209  
   210  	if ctr.joinBat1 == nil {
   211  		ctr.joinBat1, ctr.cfs1 = colexec.NewJoinBatch(ap.bat, proc.Mp())
   212  	}
   213  	if ctr.joinBat2 == nil && ctr.batchRowCount > 0 {
   214  		ctr.joinBat2, ctr.cfs2 = colexec.NewJoinBatch(ctr.batches[0], proc.Mp())
   215  	}
   216  
   217  	count := ap.bat.RowCount()
   218  	mSels := ctr.mp.Sels()
   219  	itr := ctr.mp.NewIterator()
   220  	for i := ap.lastrow; i < count; i += hashmap.UnitLimit {
   221  		if ctr.rbat.RowCount() >= colexec.DefaultBatchSize {
   222  			anal.Output(ctr.rbat, isLast)
   223  			result.Batch = ctr.rbat
   224  			ap.lastrow = i
   225  			return nil
   226  		}
   227  		n := count - i
   228  		if n > hashmap.UnitLimit {
   229  			n = hashmap.UnitLimit
   230  		}
   231  		copy(ctr.inBuckets, hashmap.OneUInt8s)
   232  		vals, zvals := itr.Find(i, n, ctr.vecs, ctr.inBuckets)
   233  		rowCount := 0
   234  		for k := 0; k < n; k++ {
   235  			if ctr.inBuckets[k] == 0 {
   236  				continue
   237  			}
   238  
   239  			// if null or not found.
   240  			// the result row was  [left cols, null cols]
   241  			if zvals[k] == 0 || vals[k] == 0 {
   242  				for j, rp := range ap.Result {
   243  					// columns from the left table.
   244  					if rp.Rel == 0 {
   245  						if err := ctr.rbat.Vecs[j].UnionOne(ap.bat.Vecs[rp.Pos], int64(i+k), proc.Mp()); err != nil {
   246  							return err
   247  						}
   248  					} else {
   249  						if err := ctr.rbat.Vecs[j].UnionNull(proc.Mp()); err != nil {
   250  							return err
   251  						}
   252  					}
   253  				}
   254  				rowCount++
   255  				continue
   256  			}
   257  
   258  			matched := false
   259  			if ap.HashOnPK {
   260  				idx1, idx2 := int64(vals[k]-1)/colexec.DefaultBatchSize, int64(vals[k]-1)%colexec.DefaultBatchSize
   261  				if ap.Cond != nil {
   262  					if err := colexec.SetJoinBatchValues(ctr.joinBat1, ap.bat, int64(i+k),
   263  						1, ctr.cfs1); err != nil {
   264  						return err
   265  					}
   266  					if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], idx2,
   267  						1, ctr.cfs2); err != nil {
   268  						return err
   269  					}
   270  					vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2})
   271  					if err != nil {
   272  						return err
   273  					}
   274  					if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   275  						continue
   276  					}
   277  					bs := vector.MustFixedCol[bool](vec)
   278  					if bs[0] {
   279  						matched = true
   280  						for j, rp := range ap.Result {
   281  							if rp.Rel == 0 {
   282  								if err := ctr.rbat.Vecs[j].UnionOne(ap.bat.Vecs[rp.Pos], int64(i+k), proc.Mp()); err != nil {
   283  									return err
   284  								}
   285  							} else {
   286  								if err := ctr.rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], idx2, proc.Mp()); err != nil {
   287  									return err
   288  								}
   289  							}
   290  						}
   291  						rowCount++
   292  					}
   293  				} else {
   294  					matched = true
   295  					for j, rp := range ap.Result {
   296  						if rp.Rel == 0 {
   297  							if err := ctr.rbat.Vecs[j].UnionMulti(ap.bat.Vecs[rp.Pos], int64(i+k), 1, proc.Mp()); err != nil {
   298  								return err
   299  							}
   300  						} else {
   301  							if err := ctr.rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], idx2, proc.Mp()); err != nil {
   302  								return err
   303  							}
   304  						}
   305  					}
   306  					rowCount++
   307  				}
   308  			} else {
   309  				sels := mSels[vals[k]-1]
   310  				if ap.Cond != nil {
   311  					if err := colexec.SetJoinBatchValues(ctr.joinBat1, ap.bat, int64(i+k),
   312  						1, ctr.cfs1); err != nil {
   313  						return err
   314  					}
   315  
   316  					for _, sel := range sels {
   317  						idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   318  						if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], int64(idx2),
   319  							1, ctr.cfs2); err != nil {
   320  							return err
   321  						}
   322  						vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2})
   323  						if err != nil {
   324  							return err
   325  						}
   326  						if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   327  							continue
   328  						}
   329  						bs := vector.MustFixedCol[bool](vec)
   330  						if !bs[0] {
   331  							continue
   332  						}
   333  						matched = true
   334  						for j, rp := range ap.Result {
   335  							if rp.Rel == 0 {
   336  								if err := ctr.rbat.Vecs[j].UnionOne(ap.bat.Vecs[rp.Pos], int64(i+k), proc.Mp()); err != nil {
   337  									return err
   338  								}
   339  							} else {
   340  								if err := ctr.rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], int64(idx2), proc.Mp()); err != nil {
   341  									return err
   342  								}
   343  							}
   344  						}
   345  						rowCount++
   346  					}
   347  				} else {
   348  					matched = true
   349  					for j, rp := range ap.Result {
   350  						if rp.Rel == 0 {
   351  							if err := ctr.rbat.Vecs[j].UnionMulti(ap.bat.Vecs[rp.Pos], int64(i+k), len(sels), proc.Mp()); err != nil {
   352  								return err
   353  							}
   354  						} else {
   355  							for _, sel := range sels {
   356  								idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   357  								if err := ctr.rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], int64(idx2), proc.Mp()); err != nil {
   358  									return err
   359  								}
   360  							}
   361  						}
   362  					}
   363  					rowCount += len(sels)
   364  				}
   365  			}
   366  
   367  			if !matched {
   368  				for j, rp := range ap.Result {
   369  					if rp.Rel == 0 {
   370  						if err := ctr.rbat.Vecs[j].UnionOne(ap.bat.Vecs[rp.Pos], int64(i+k), proc.Mp()); err != nil {
   371  							return err
   372  						}
   373  					} else {
   374  						if err := ctr.rbat.Vecs[j].UnionNull(proc.Mp()); err != nil {
   375  							return err
   376  						}
   377  					}
   378  				}
   379  				rowCount++
   380  				continue
   381  			}
   382  		}
   383  
   384  		ctr.rbat.SetRowCount(ctr.rbat.RowCount() + rowCount)
   385  	}
   386  	anal.Output(ctr.rbat, isLast)
   387  	result.Batch = ctr.rbat
   388  	ap.lastrow = 0
   389  	return nil
   390  }
   391  
   392  func (ctr *container) evalJoinCondition(bat *batch.Batch, proc *process.Process) error {
   393  	for i := range ctr.evecs {
   394  		vec, err := ctr.evecs[i].executor.Eval(proc, []*batch.Batch{bat})
   395  		if err != nil {
   396  			return err
   397  		}
   398  		ctr.vecs[i] = vec
   399  		ctr.evecs[i].vec = vec
   400  	}
   401  	return nil
   402  }