github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/loopanti/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package loopanti
    16  
    17  import (
    18  	"bytes"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    21  	"github.com/matrixorigin/matrixone/pkg/container/nulls"
    22  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    23  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    24  	"github.com/matrixorigin/matrixone/pkg/vm"
    25  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    26  )
    27  
    28  const argName = "loop_anti"
    29  
    30  func (arg *Argument) String(buf *bytes.Buffer) {
    31  	buf.WriteString(argName)
    32  	buf.WriteString(": loop anti join ")
    33  }
    34  
    35  func (arg *Argument) Prepare(proc *process.Process) error {
    36  	var err error
    37  
    38  	arg.ctr = new(container)
    39  	arg.ctr.InitReceiver(proc, false)
    40  
    41  	if arg.Cond != nil {
    42  		arg.ctr.expr, err = colexec.NewExpressionExecutor(proc, arg.Cond)
    43  	}
    44  	return err
    45  }
    46  
    47  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    48  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    49  		return vm.CancelResult, err
    50  	}
    51  
    52  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    53  	anal.Start()
    54  	defer anal.Stop()
    55  	ctr := arg.ctr
    56  	result := vm.NewCallResult()
    57  	for {
    58  		switch ctr.state {
    59  		case Build:
    60  			if err := ctr.build(proc, anal); err != nil {
    61  				return result, err
    62  			}
    63  			ctr.state = Probe
    64  
    65  		case Probe:
    66  			var err error
    67  			if arg.bat == nil {
    68  				arg.bat, _, err = ctr.ReceiveFromSingleReg(0, anal)
    69  				if err != nil {
    70  					return result, err
    71  				}
    72  
    73  				if arg.bat == nil {
    74  					ctr.state = End
    75  					continue
    76  				}
    77  				if arg.bat.RowCount() == 0 {
    78  					proc.PutBatch(arg.bat)
    79  					arg.bat = nil
    80  					continue
    81  				}
    82  				arg.lastrow = 0
    83  			}
    84  
    85  			if ctr.bat == nil || ctr.bat.RowCount() == 0 {
    86  				err = ctr.emptyProbe(arg, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result)
    87  			} else {
    88  				err = ctr.probe(arg, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result)
    89  			}
    90  			if arg.lastrow == 0 {
    91  				proc.PutBatch(arg.bat)
    92  				arg.bat = nil
    93  			}
    94  			return result, err
    95  		default:
    96  			result.Batch = nil
    97  			result.Status = vm.ExecStop
    98  			return result, nil
    99  		}
   100  	}
   101  }
   102  
   103  func (ctr *container) build(proc *process.Process, anal process.Analyze) error {
   104  	for {
   105  		bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   106  		if err != nil {
   107  			return err
   108  		}
   109  		if bat == nil {
   110  			break
   111  		}
   112  		ctr.bat, err = ctr.bat.AppendWithCopy(proc.Ctx, proc.Mp(), bat)
   113  		if err != nil {
   114  			return err
   115  		}
   116  		proc.PutBatch(bat)
   117  	}
   118  	return nil
   119  }
   120  
   121  func (ctr *container) emptyProbe(ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   122  	anal.Input(ap.bat, isFirst)
   123  	if ctr.rbat != nil {
   124  		proc.PutBatch(ctr.rbat)
   125  		ctr.rbat = nil
   126  	}
   127  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   128  	for i, pos := range ap.Result {
   129  		// rbat.Vecs[i] = bat.Vecs[pos]
   130  		// bat.Vecs[pos] = nil
   131  		typ := *ap.bat.Vecs[pos].GetType()
   132  		ctr.rbat.Vecs[i] = proc.GetVector(typ)
   133  		if err := vector.GetUnionAllFunction(typ, proc.Mp())(ctr.rbat.Vecs[i], ap.bat.Vecs[pos]); err != nil {
   134  			return err
   135  		}
   136  	}
   137  	ctr.rbat.AddRowCount(ap.bat.RowCount())
   138  	anal.Output(ctr.rbat, isLast)
   139  	result.Batch = ctr.rbat
   140  	ap.lastrow = 0
   141  	return nil
   142  }
   143  
   144  func (ctr *container) probe(ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   145  	anal.Input(ap.bat, isFirst)
   146  	if ctr.rbat != nil {
   147  		proc.PutBatch(ctr.rbat)
   148  		ctr.rbat = nil
   149  	}
   150  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   151  	for i, pos := range ap.Result {
   152  		ctr.rbat.Vecs[i] = proc.GetVector(*ap.bat.Vecs[pos].GetType())
   153  	}
   154  	count := ap.bat.RowCount()
   155  	if ctr.joinBat == nil {
   156  		ctr.joinBat, ctr.cfs = colexec.NewJoinBatch(ap.bat, proc.Mp())
   157  	}
   158  
   159  	rowCountIncrease := 0
   160  	for i := ap.lastrow; i < count; i++ {
   161  		if rowCountIncrease >= colexec.DefaultBatchSize {
   162  			ctr.rbat.SetRowCount(ctr.rbat.RowCount() + rowCountIncrease)
   163  			anal.Output(ctr.rbat, isLast)
   164  			result.Batch = ctr.rbat
   165  			ap.lastrow = i
   166  			return nil
   167  		}
   168  		if err := colexec.SetJoinBatchValues(ctr.joinBat, ap.bat, int64(i),
   169  			ctr.bat.RowCount(), ctr.cfs); err != nil {
   170  			return err
   171  		}
   172  		matched := false
   173  		vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat, ctr.bat})
   174  		if err != nil {
   175  			return err
   176  		}
   177  
   178  		rs := vector.GenerateFunctionFixedTypeParameter[bool](vec)
   179  		for k := uint64(0); k < uint64(vec.Length()); k++ {
   180  			b, null := rs.GetValue(k)
   181  			if !null && b {
   182  				matched = true
   183  				break
   184  			}
   185  		}
   186  		if !matched && !nulls.Any(vec.GetNulls()) {
   187  			for k, pos := range ap.Result {
   188  				if err := ctr.rbat.Vecs[k].UnionOne(ap.bat.Vecs[pos], int64(i), proc.Mp()); err != nil {
   189  					return err
   190  				}
   191  			}
   192  			rowCountIncrease++
   193  		}
   194  	}
   195  	ctr.rbat.SetRowCount(ctr.rbat.RowCount() + rowCountIncrease)
   196  	anal.Output(ctr.rbat, isLast)
   197  
   198  	result.Batch = ctr.rbat
   199  	ap.lastrow = 0
   200  	return nil
   201  }