github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/loopsemi/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package loopsemi
    16  
    17  import (
    18  	"bytes"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    21  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    22  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    23  	"github.com/matrixorigin/matrixone/pkg/vm"
    24  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    25  )
    26  
    27  const argName = "loop_semi"
    28  
    29  func (arg *Argument) String(buf *bytes.Buffer) {
    30  	buf.WriteString(argName)
    31  	buf.WriteString(": ⨝ ")
    32  }
    33  
    34  func (arg *Argument) Prepare(proc *process.Process) error {
    35  	var err error
    36  
    37  	arg.ctr = new(container)
    38  	arg.ctr.InitReceiver(proc, false)
    39  
    40  	if arg.Cond != nil {
    41  		arg.ctr.expr, err = colexec.NewExpressionExecutor(proc, arg.Cond)
    42  	}
    43  	return err
    44  }
    45  
    46  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    47  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    48  		return vm.CancelResult, err
    49  	}
    50  
    51  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    52  	anal.Start()
    53  	defer anal.Stop()
    54  	ctr := arg.ctr
    55  	result := vm.NewCallResult()
    56  	for {
    57  		switch ctr.state {
    58  		case Build:
    59  			if err := ctr.build(proc, anal); err != nil {
    60  				return result, err
    61  			}
    62  			if ctr.bat == nil {
    63  				// for inner ,right and semi join, if hashmap is empty, we can finish this pipeline
    64  				ctr.state = End
    65  			} else {
    66  				ctr.state = Probe
    67  			}
    68  
    69  		case Probe:
    70  			if arg.bat == nil {
    71  				bat, _, err := ctr.ReceiveFromSingleReg(0, anal)
    72  				if err != nil {
    73  					return result, err
    74  				}
    75  
    76  				if bat == nil {
    77  					ctr.state = End
    78  					continue
    79  				}
    80  				if bat.IsEmpty() {
    81  					proc.PutBatch(bat)
    82  					continue
    83  				}
    84  				if ctr.bat == nil || ctr.bat.RowCount() == 0 {
    85  					proc.PutBatch(bat)
    86  					continue
    87  				}
    88  				arg.bat = bat
    89  				arg.lastrow = 0
    90  			}
    91  
    92  			err := ctr.probe(arg, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result)
    93  			if arg.lastrow == 0 {
    94  				proc.PutBatch(arg.bat)
    95  				arg.bat = nil
    96  			}
    97  			return result, err
    98  
    99  		default:
   100  			result.Batch = nil
   101  			result.Status = vm.ExecStop
   102  			return result, nil
   103  		}
   104  	}
   105  }
   106  
   107  func (ctr *container) build(proc *process.Process, anal process.Analyze) error {
   108  	for {
   109  		bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   110  		if err != nil {
   111  			return err
   112  		}
   113  		if bat == nil {
   114  			break
   115  		}
   116  		ctr.bat, err = ctr.bat.AppendWithCopy(proc.Ctx, proc.Mp(), bat)
   117  		if err != nil {
   118  			return err
   119  		}
   120  		proc.PutBatch(bat)
   121  	}
   122  	return nil
   123  }
   124  
   125  func (ctr *container) probe(ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   126  	anal.Input(ap.bat, isFirst)
   127  	if ctr.rbat != nil {
   128  		proc.PutBatch(ctr.rbat)
   129  		ctr.rbat = nil
   130  	}
   131  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   132  	for i, pos := range ap.Result {
   133  		ctr.rbat.Vecs[i] = proc.GetVector(*ap.bat.Vecs[pos].GetType())
   134  	}
   135  	count := ap.bat.RowCount()
   136  	if ctr.joinBat == nil {
   137  		ctr.joinBat, ctr.cfs = colexec.NewJoinBatch(ap.bat, proc.Mp())
   138  	}
   139  
   140  	rowCountIncrease := 0
   141  	for i := ap.lastrow; i < count; i++ {
   142  		if rowCountIncrease >= colexec.DefaultBatchSize {
   143  			ctr.rbat.SetRowCount(ctr.rbat.RowCount() + rowCountIncrease)
   144  			anal.Output(ctr.rbat, isLast)
   145  			result.Batch = ctr.rbat
   146  			ap.lastrow = i
   147  			return nil
   148  		}
   149  		if err := colexec.SetJoinBatchValues(ctr.joinBat, ap.bat, int64(i),
   150  			ctr.bat.RowCount(), ctr.cfs); err != nil {
   151  			return err
   152  		}
   153  		vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat, ctr.bat})
   154  		if err != nil {
   155  			return err
   156  		}
   157  
   158  		rs := vector.GenerateFunctionFixedTypeParameter[bool](vec)
   159  		for k := uint64(0); k < uint64(vec.Length()); k++ {
   160  			b, null := rs.GetValue(k)
   161  			if !null && b {
   162  				for k, pos := range ap.Result {
   163  					if err = ctr.rbat.Vecs[k].UnionOne(ap.bat.Vecs[pos], int64(i), proc.Mp()); err != nil {
   164  						return err
   165  					}
   166  				}
   167  				rowCountIncrease++
   168  				break
   169  			}
   170  		}
   171  	}
   172  	ctr.rbat.SetRowCount(ctr.rbat.RowCount() + rowCountIncrease)
   173  	anal.Output(ctr.rbat, isLast)
   174  	result.Batch = ctr.rbat
   175  	ap.lastrow = 0
   176  	return nil
   177  }