github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/loopjoin/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package loopjoin
    16  
    17  import (
    18  	"bytes"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    21  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    22  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    23  	"github.com/matrixorigin/matrixone/pkg/vm"
    24  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    25  )
    26  
    27  const argName = "loop_join"
    28  
    29  func (arg *Argument) String(buf *bytes.Buffer) {
    30  	buf.WriteString(argName)
    31  	buf.WriteString(": loop join ")
    32  }
    33  
    34  func (arg *Argument) Prepare(proc *process.Process) error {
    35  	var err error
    36  
    37  	arg.ctr = new(container)
    38  	arg.ctr.InitReceiver(proc, false)
    39  
    40  	if arg.Cond != nil {
    41  		arg.ctr.expr, err = colexec.NewExpressionExecutor(proc, arg.Cond)
    42  	}
    43  	return err
    44  }
    45  
    46  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    47  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    48  		return vm.CancelResult, err
    49  	}
    50  
    51  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    52  	anal.Start()
    53  	defer anal.Stop()
    54  	ctr := arg.ctr
    55  	result := vm.NewCallResult()
    56  	var err error
    57  	for {
    58  		switch ctr.state {
    59  		case Build:
    60  			if err := ctr.build(proc, anal); err != nil {
    61  				return result, err
    62  			}
    63  			if ctr.bat == nil {
    64  				// for inner ,right and semi join, if hashmap is empty, we can finish this pipeline
    65  				ctr.state = End
    66  			} else {
    67  				ctr.state = Probe
    68  			}
    69  
    70  		case Probe:
    71  			if ctr.inBat != nil {
    72  				err = ctr.probe(arg, proc, anal, arg.GetIsLast(), &result)
    73  				return result, err
    74  			}
    75  			ctr.inBat, _, err = ctr.ReceiveFromSingleReg(0, anal)
    76  			if err != nil {
    77  				return result, err
    78  			}
    79  
    80  			if ctr.inBat == nil {
    81  				ctr.state = End
    82  				continue
    83  			}
    84  			if ctr.inBat.IsEmpty() {
    85  				proc.PutBatch(ctr.inBat)
    86  				ctr.inBat = nil
    87  				continue
    88  			}
    89  			if ctr.bat == nil || ctr.bat.RowCount() == 0 {
    90  				proc.PutBatch(ctr.inBat)
    91  				ctr.inBat = nil
    92  				continue
    93  			}
    94  			anal.Input(ctr.inBat, arg.GetIsFirst())
    95  			err = ctr.probe(arg, proc, anal, arg.GetIsLast(), &result)
    96  			return result, err
    97  		default:
    98  			result.Batch = nil
    99  			result.Status = vm.ExecStop
   100  			return result, nil
   101  		}
   102  	}
   103  }
   104  
   105  func (ctr *container) build(proc *process.Process, anal process.Analyze) error {
   106  	for {
   107  		bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   108  		if err != nil {
   109  			return err
   110  		}
   111  		if bat == nil {
   112  			break
   113  		}
   114  		ctr.bat, err = ctr.bat.AppendWithCopy(proc.Ctx, proc.Mp(), bat)
   115  		if err != nil {
   116  			return err
   117  		}
   118  		proc.PutBatch(bat)
   119  	}
   120  	return nil
   121  }
   122  
   123  func (ctr *container) probe(ap *Argument, proc *process.Process, anal process.Analyze, isLast bool, result *vm.CallResult) error {
   124  	if ctr.rbat != nil {
   125  		proc.PutBatch(ctr.rbat)
   126  		ctr.rbat = nil
   127  	}
   128  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   129  	for i, rp := range ap.Result {
   130  		if rp.Rel == 0 {
   131  			ctr.rbat.Vecs[i] = proc.GetVector(*ctr.inBat.Vecs[rp.Pos].GetType())
   132  		} else {
   133  			ctr.rbat.Vecs[i] = proc.GetVector(*ctr.bat.Vecs[rp.Pos].GetType())
   134  		}
   135  	}
   136  	count := ctr.inBat.RowCount()
   137  	if ctr.joinBat == nil {
   138  		ctr.joinBat, ctr.cfs = colexec.NewJoinBatch(ctr.inBat, proc.Mp())
   139  	}
   140  
   141  	rowCountIncrease := 0
   142  	for i := ctr.probeIdx; i < count; i++ {
   143  		if err := colexec.SetJoinBatchValues(ctr.joinBat, ctr.inBat, int64(i),
   144  			ctr.bat.RowCount(), ctr.cfs); err != nil {
   145  			return err
   146  		}
   147  		vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat, ctr.bat})
   148  		if err != nil {
   149  			return err
   150  		}
   151  
   152  		rs := vector.GenerateFunctionFixedTypeParameter[bool](vec)
   153  		if vec.IsConst() {
   154  			b, null := rs.GetValue(0)
   155  			if !null && b {
   156  				for j := 0; j < ctr.bat.RowCount(); j++ {
   157  					for k, rp := range ap.Result {
   158  						if rp.Rel == 0 {
   159  							if err = ctr.rbat.Vecs[k].UnionOne(ctr.inBat.Vecs[rp.Pos], int64(i), proc.Mp()); err != nil {
   160  								return err
   161  							}
   162  						} else {
   163  							if err = ctr.rbat.Vecs[k].UnionOne(ctr.bat.Vecs[rp.Pos], int64(j), proc.Mp()); err != nil {
   164  								return err
   165  							}
   166  						}
   167  					}
   168  					rowCountIncrease++
   169  				}
   170  			}
   171  		} else {
   172  			l := uint64(ctr.bat.RowCount())
   173  			for j := uint64(0); j < l; j++ {
   174  				b, null := rs.GetValue(j)
   175  				if !null && b {
   176  					for k, rp := range ap.Result {
   177  						if rp.Rel == 0 {
   178  							if err = ctr.rbat.Vecs[k].UnionOne(ctr.inBat.Vecs[rp.Pos], int64(i), proc.Mp()); err != nil {
   179  								return err
   180  							}
   181  						} else {
   182  							if err = ctr.rbat.Vecs[k].UnionOne(ctr.bat.Vecs[rp.Pos], int64(j), proc.Mp()); err != nil {
   183  								return err
   184  							}
   185  						}
   186  					}
   187  					rowCountIncrease++
   188  				}
   189  			}
   190  		}
   191  		if rowCountIncrease >= colexec.DefaultBatchSize {
   192  			anal.Output(ctr.rbat, isLast)
   193  			result.Batch = ctr.rbat
   194  			ctr.rbat.SetRowCount(rowCountIncrease)
   195  			ctr.probeIdx = i + 1
   196  			return nil
   197  		}
   198  	}
   199  
   200  	ctr.probeIdx = 0
   201  	ctr.rbat.SetRowCount(rowCountIncrease)
   202  	anal.Output(ctr.rbat, isLast)
   203  	result.Batch = ctr.rbat
   204  	proc.PutBatch(ctr.inBat)
   205  	ctr.inBat = nil
   206  	return nil
   207  }