github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/loopmark/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package loopmark
    16  
    17  import (
    18  	"bytes"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    21  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    22  	"github.com/matrixorigin/matrixone/pkg/container/types"
    23  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    24  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    25  	"github.com/matrixorigin/matrixone/pkg/vm"
    26  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    27  )
    28  
    29  const argName = "loop_mark"
    30  
    31  func (arg *Argument) String(buf *bytes.Buffer) {
    32  	buf.WriteString(argName)
    33  	buf.WriteString(": loop mark join ")
    34  }
    35  
    36  func (arg *Argument) Prepare(proc *process.Process) error {
    37  	var err error
    38  
    39  	arg.ctr = new(container)
    40  	arg.ctr.InitReceiver(proc, false)
    41  	arg.ctr.bat = batch.NewWithSize(len(arg.Typs))
    42  	for i, typ := range arg.Typs {
    43  		arg.ctr.bat.Vecs[i] = proc.GetVector(typ)
    44  	}
    45  
    46  	if arg.Cond != nil {
    47  		arg.ctr.expr, err = colexec.NewExpressionExecutor(proc, arg.Cond)
    48  	}
    49  	return err
    50  }
    51  
    52  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    53  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    54  		return vm.CancelResult, err
    55  	}
    56  
    57  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    58  	anal.Start()
    59  	defer anal.Stop()
    60  	ctr := arg.ctr
    61  	result := vm.NewCallResult()
    62  	for {
    63  		switch ctr.state {
    64  		case Build:
    65  			if err := ctr.build(proc, anal); err != nil {
    66  				return result, err
    67  			}
    68  			ctr.state = Probe
    69  
    70  		case Probe:
    71  			var err error
    72  			bat, _, err := ctr.ReceiveFromSingleReg(0, anal)
    73  			if err != nil {
    74  				return result, err
    75  			}
    76  
    77  			if bat == nil {
    78  				ctr.state = End
    79  				continue
    80  			}
    81  			if bat.IsEmpty() {
    82  				proc.PutBatch(bat)
    83  				continue
    84  			}
    85  			if ctr.bat.RowCount() == 0 {
    86  				err = ctr.emptyProbe(bat, arg, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result)
    87  			} else {
    88  				err = ctr.probe(bat, arg, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result)
    89  			}
    90  			proc.PutBatch(bat)
    91  			return result, err
    92  
    93  		default:
    94  			result.Batch = nil
    95  			result.Status = vm.ExecStop
    96  			return result, nil
    97  		}
    98  	}
    99  }
   100  
   101  func (ctr *container) build(proc *process.Process, anal process.Analyze) error {
   102  	for {
   103  		bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   104  		if err != nil {
   105  			return err
   106  		}
   107  		if bat == nil {
   108  			break
   109  		}
   110  		ctr.bat, err = ctr.bat.AppendWithCopy(proc.Ctx, proc.Mp(), bat)
   111  		if err != nil {
   112  			return err
   113  		}
   114  		proc.PutBatch(bat)
   115  	}
   116  	return nil
   117  }
   118  
   119  func (ctr *container) emptyProbe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) (err error) {
   120  	anal.Input(bat, isFirst)
   121  	if ctr.rbat != nil {
   122  		proc.PutBatch(ctr.rbat)
   123  		ctr.rbat = nil
   124  	}
   125  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   126  	count := bat.RowCount()
   127  	for i, rp := range ap.Result {
   128  		if rp >= 0 {
   129  			typ := *bat.Vecs[rp].GetType()
   130  			ctr.rbat.Vecs[i] = proc.GetVector(typ)
   131  			if err = vector.GetUnionAllFunction(typ, proc.Mp())(ctr.rbat.Vecs[i], bat.Vecs[rp]); err != nil {
   132  				return err
   133  			}
   134  		} else {
   135  			if ctr.rbat.Vecs[i], err = vector.NewConstFixed(types.T_bool.ToType(), false, count, proc.Mp()); err != nil {
   136  				return err
   137  			}
   138  		}
   139  	}
   140  	ctr.rbat.AddRowCount(bat.RowCount())
   141  	anal.Output(ctr.rbat, isLast)
   142  
   143  	result.Batch = ctr.rbat
   144  	return nil
   145  }
   146  
   147  func (ctr *container) probe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   148  	anal.Input(bat, isFirst)
   149  	if ctr.rbat != nil {
   150  		proc.PutBatch(ctr.rbat)
   151  		ctr.rbat = nil
   152  	}
   153  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   154  	markPos := -1
   155  	for i, pos := range ap.Result {
   156  		if pos == -1 {
   157  			ctr.rbat.Vecs[i] = proc.GetVector(types.T_bool.ToType())
   158  			if err := ctr.rbat.Vecs[i].PreExtend(bat.RowCount(), proc.Mp()); err != nil {
   159  				return err
   160  			}
   161  			markPos = i
   162  			break
   163  		}
   164  	}
   165  	if markPos == -1 {
   166  		return moerr.NewInternalError(proc.Ctx, "MARK join must output mark column")
   167  	}
   168  	count := bat.RowCount()
   169  	if ctr.joinBat == nil {
   170  		ctr.joinBat, ctr.cfs = colexec.NewJoinBatch(bat, proc.Mp())
   171  	}
   172  	for i := 0; i < count; i++ {
   173  		if err := colexec.SetJoinBatchValues(ctr.joinBat, bat, int64(i),
   174  			ctr.bat.RowCount(), ctr.cfs); err != nil {
   175  			return err
   176  		}
   177  		vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat, ctr.bat})
   178  		if err != nil {
   179  			return err
   180  		}
   181  
   182  		rs := vector.GenerateFunctionFixedTypeParameter[bool](vec)
   183  		hasTrue := false
   184  		hasNull := false
   185  		if vec.IsConst() {
   186  			v, null := rs.GetValue(0)
   187  			if null {
   188  				hasNull = true
   189  			} else {
   190  				hasTrue = v
   191  			}
   192  		} else {
   193  			for j := uint64(0); j < uint64(vec.Length()); j++ {
   194  				val, null := rs.GetValue(j)
   195  				if null {
   196  					hasNull = true
   197  				} else if val {
   198  					hasTrue = true
   199  				}
   200  			}
   201  		}
   202  		if hasTrue {
   203  			err = vector.AppendFixed(ctr.rbat.Vecs[markPos], true, false, proc.Mp())
   204  		} else if hasNull {
   205  			err = vector.AppendFixed(ctr.rbat.Vecs[markPos], false, true, proc.Mp())
   206  		} else {
   207  			err = vector.AppendFixed(ctr.rbat.Vecs[markPos], false, false, proc.Mp())
   208  		}
   209  		if err != nil {
   210  			return err
   211  		}
   212  	}
   213  	for i, rp := range ap.Result {
   214  		if rp >= 0 {
   215  			ctr.rbat.Vecs[i] = bat.Vecs[rp]
   216  			bat.Vecs[rp] = nil
   217  		}
   218  	}
   219  	ctr.rbat.AddRowCount(bat.RowCount())
   220  	anal.Output(ctr.rbat, isLast)
   221  	result.Batch = ctr.rbat
   222  	return nil
   223  }