github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/loopsingle/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package loopsingle
    16  
    17  import (
    18  	"bytes"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    21  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    22  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    23  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    24  	"github.com/matrixorigin/matrixone/pkg/vm"
    25  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    26  )
    27  
    28  const argName = "loop_single"
    29  
    30  func (arg *Argument) String(buf *bytes.Buffer) {
    31  	buf.WriteString(argName)
    32  	buf.WriteString(": loop single join ")
    33  }
    34  
    35  func (arg *Argument) Prepare(proc *process.Process) error {
    36  	var err error
    37  
    38  	arg.ctr = new(container)
    39  	arg.ctr.InitReceiver(proc, false)
    40  	arg.ctr.bat = batch.NewWithSize(len(arg.Typs))
    41  	for i, typ := range arg.Typs {
    42  		arg.ctr.bat.Vecs[i] = proc.GetVector(typ)
    43  	}
    44  
    45  	if arg.Cond != nil {
    46  		arg.ctr.expr, err = colexec.NewExpressionExecutor(proc, arg.Cond)
    47  	}
    48  	return err
    49  }
    50  
    51  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    52  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    53  		return vm.CancelResult, err
    54  	}
    55  
    56  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    57  	anal.Start()
    58  	defer anal.Stop()
    59  	ctr := arg.ctr
    60  	result := vm.NewCallResult()
    61  	for {
    62  		switch ctr.state {
    63  		case Build:
    64  			if err := ctr.build(proc, anal); err != nil {
    65  				return result, err
    66  			}
    67  			ctr.state = Probe
    68  
    69  		case Probe:
    70  			var err error
    71  			bat, _, err := ctr.ReceiveFromSingleReg(0, anal)
    72  			if err != nil {
    73  				return result, err
    74  			}
    75  
    76  			if bat == nil {
    77  				ctr.state = End
    78  				continue
    79  			}
    80  			if bat.IsEmpty() {
    81  				proc.PutBatch(bat)
    82  				continue
    83  			}
    84  			if ctr.bat.RowCount() == 0 {
    85  				err = ctr.emptyProbe(bat, arg, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result)
    86  			} else {
    87  				err = ctr.probe(bat, arg, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result)
    88  			}
    89  			proc.PutBatch(bat)
    90  
    91  			return result, err
    92  
    93  		default:
    94  			result.Batch = nil
    95  			result.Status = vm.ExecStop
    96  			return result, nil
    97  		}
    98  	}
    99  }
   100  
   101  func (ctr *container) build(proc *process.Process, anal process.Analyze) error {
   102  	for {
   103  		bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   104  		if err != nil {
   105  			return err
   106  		}
   107  		if bat == nil {
   108  			break
   109  		}
   110  		ctr.bat, err = ctr.bat.AppendWithCopy(proc.Ctx, proc.Mp(), bat)
   111  		if err != nil {
   112  			return err
   113  		}
   114  		proc.PutBatch(bat)
   115  	}
   116  	return nil
   117  }
   118  
   119  func (ctr *container) emptyProbe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   120  	anal.Input(bat, isFirst)
   121  	if ctr.rbat != nil {
   122  		proc.PutBatch(ctr.rbat)
   123  		ctr.rbat = nil
   124  	}
   125  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   126  	for i, rp := range ap.Result {
   127  		if rp.Rel == 0 {
   128  			ctr.rbat.Vecs[i] = bat.Vecs[rp.Pos]
   129  			bat.Vecs[rp.Pos] = nil
   130  		} else {
   131  			ctr.rbat.Vecs[i] = vector.NewConstNull(ap.Typs[rp.Pos], bat.RowCount(), proc.Mp())
   132  		}
   133  	}
   134  	ctr.rbat.SetRowCount(ctr.rbat.RowCount() + bat.RowCount())
   135  	anal.Output(ctr.rbat, isLast)
   136  	result.Batch = ctr.rbat
   137  	return nil
   138  }
   139  
   140  func (ctr *container) probe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   141  	anal.Input(bat, isFirst)
   142  	if ctr.rbat != nil {
   143  		proc.PutBatch(ctr.rbat)
   144  		ctr.rbat = nil
   145  	}
   146  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   147  	for i, rp := range ap.Result {
   148  		if rp.Rel != 0 {
   149  			ctr.rbat.Vecs[i] = proc.GetVector(ap.Typs[rp.Pos])
   150  		}
   151  	}
   152  	count := bat.RowCount()
   153  	if ctr.expr == nil {
   154  		switch ctr.bat.RowCount() {
   155  		case 0:
   156  			for i, rp := range ap.Result {
   157  				if rp.Rel != 0 {
   158  					err := vector.AppendMultiFixed(ctr.rbat.Vecs[i], 0, true, count, proc.Mp())
   159  					if err != nil {
   160  						return err
   161  					}
   162  				}
   163  			}
   164  		case 1:
   165  			for i, rp := range ap.Result {
   166  				if rp.Rel != 0 {
   167  					err := ctr.rbat.Vecs[i].UnionMulti(ctr.bat.Vecs[rp.Pos], 0, count, proc.Mp())
   168  					if err != nil {
   169  						return err
   170  					}
   171  				}
   172  			}
   173  		default:
   174  			return moerr.NewInternalError(proc.Ctx, "scalar subquery returns more than 1 row")
   175  		}
   176  	} else {
   177  		if ctr.joinBat == nil {
   178  			ctr.joinBat, ctr.cfs = colexec.NewJoinBatch(bat, proc.Mp())
   179  		}
   180  		for i := 0; i < count; i++ {
   181  			if err := colexec.SetJoinBatchValues(ctr.joinBat, bat, int64(i),
   182  				ctr.bat.RowCount(), ctr.cfs); err != nil {
   183  				return err
   184  			}
   185  			unmatched := true
   186  			vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat, ctr.bat})
   187  			if err != nil {
   188  				return err
   189  			}
   190  
   191  			rs := vector.GenerateFunctionFixedTypeParameter[bool](vec)
   192  			if vec.IsConst() {
   193  				b, null := rs.GetValue(0)
   194  				if !null && b {
   195  					if ctr.bat.RowCount() > 1 {
   196  						return moerr.NewInternalError(proc.Ctx, "scalar subquery returns more than 1 row")
   197  					}
   198  					unmatched = false
   199  					for k, rp := range ap.Result {
   200  						if rp.Rel != 0 {
   201  							if err := ctr.rbat.Vecs[k].UnionOne(ctr.bat.Vecs[rp.Pos], 0, proc.Mp()); err != nil {
   202  								return err
   203  							}
   204  						}
   205  					}
   206  				}
   207  			} else {
   208  				l := vec.Length()
   209  				for j := uint64(0); j < uint64(l); j++ {
   210  					b, null := rs.GetValue(j)
   211  					if !null && b {
   212  						if !unmatched {
   213  							return moerr.NewInternalError(proc.Ctx, "scalar subquery returns more than 1 row")
   214  						}
   215  						unmatched = false
   216  						for k, rp := range ap.Result {
   217  							if rp.Rel != 0 {
   218  								if err := ctr.rbat.Vecs[k].UnionOne(ctr.bat.Vecs[rp.Pos], int64(j), proc.Mp()); err != nil {
   219  									return err
   220  								}
   221  							}
   222  						}
   223  					}
   224  				}
   225  			}
   226  			if unmatched {
   227  				for k, rp := range ap.Result {
   228  					if rp.Rel != 0 {
   229  						if err := ctr.rbat.Vecs[k].UnionNull(proc.Mp()); err != nil {
   230  							return err
   231  						}
   232  					}
   233  				}
   234  			}
   235  		}
   236  	}
   237  	for i, rp := range ap.Result {
   238  		if rp.Rel == 0 {
   239  			// rbat.Vecs[i] = bat.Vecs[rp.Pos]
   240  			// bat.Vecs[rp.Pos] = nil
   241  			typ := *bat.Vecs[rp.Pos].GetType()
   242  			ctr.rbat.Vecs[i] = proc.GetVector(typ)
   243  			if err := vector.GetUnionAllFunction(typ, proc.Mp())(ctr.rbat.Vecs[i], bat.Vecs[rp.Pos]); err != nil {
   244  				return err
   245  			}
   246  		}
   247  	}
   248  	ctr.rbat.AddRowCount(bat.RowCount())
   249  	anal.Output(ctr.rbat, isLast)
   250  	result.Batch = ctr.rbat
   251  	return nil
   252  }