github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/single/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package single
    16  
    17  import (
    18  	"bytes"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    21  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    22  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    23  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    24  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    25  	"github.com/matrixorigin/matrixone/pkg/vm"
    26  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    27  )
    28  
    29  const argName = "single"
    30  
    31  func (arg *Argument) String(buf *bytes.Buffer) {
    32  	buf.WriteString(argName)
    33  	buf.WriteString(": single join ")
    34  }
    35  
    36  func (arg *Argument) Prepare(proc *process.Process) (err error) {
    37  	ap := arg
    38  	ap.ctr = new(container)
    39  	ap.ctr.InitReceiver(proc, false)
    40  	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    41  	ap.ctr.vecs = make([]*vector.Vector, len(ap.Conditions[0]))
    42  
    43  	ap.ctr.evecs = make([]evalVector, len(ap.Conditions[0]))
    44  	for i := range ap.ctr.evecs {
    45  		ap.ctr.evecs[i].executor, err = colexec.NewExpressionExecutor(proc, ap.Conditions[0][i])
    46  		if err != nil {
    47  			return err
    48  		}
    49  	}
    50  
    51  	if ap.Cond != nil {
    52  		ap.ctr.expr, err = colexec.NewExpressionExecutor(proc, ap.Cond)
    53  	}
    54  	return err
    55  }
    56  
    57  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    58  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    59  		return vm.CancelResult, err
    60  	}
    61  
    62  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    63  	anal.Start()
    64  	defer anal.Stop()
    65  	ap := arg
    66  	ctr := ap.ctr
    67  	result := vm.NewCallResult()
    68  	for {
    69  		switch ctr.state {
    70  		case Build:
    71  			if err := ctr.build(anal); err != nil {
    72  				return result, err
    73  			}
    74  			ctr.state = Probe
    75  
    76  		case Probe:
    77  			bat, _, err := ctr.ReceiveFromSingleReg(0, anal)
    78  			if err != nil {
    79  				return result, err
    80  			}
    81  
    82  			if bat == nil {
    83  				ctr.state = End
    84  				continue
    85  			}
    86  			if bat.Last() {
    87  				result.Batch = bat
    88  				return result, nil
    89  			}
    90  			if bat.IsEmpty() {
    91  				proc.PutBatch(bat)
    92  				continue
    93  			}
    94  			if ctr.mp == nil {
    95  				if err := ctr.emptyProbe(bat, ap, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result); err != nil {
    96  					bat.Clean(proc.Mp())
    97  					result.Status = vm.ExecStop
    98  					return result, err
    99  				}
   100  			} else {
   101  				if err := ctr.probe(bat, ap, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result); err != nil {
   102  					bat.Clean(proc.Mp())
   103  					result.Status = vm.ExecStop
   104  					return result, err
   105  				}
   106  			}
   107  			proc.PutBatch(bat)
   108  			return result, nil
   109  
   110  		default:
   111  			result.Batch = nil
   112  			result.Status = vm.ExecStop
   113  			return result, nil
   114  		}
   115  	}
   116  }
   117  
   118  func (ctr *container) receiveHashMap(anal process.Analyze) error {
   119  	bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   120  	if err != nil {
   121  		return err
   122  	}
   123  	if bat != nil && bat.AuxData != nil {
   124  		ctr.mp = bat.DupJmAuxData()
   125  		ctr.maxAllocSize = max(ctr.maxAllocSize, ctr.mp.Size())
   126  	}
   127  	return nil
   128  }
   129  
   130  func (ctr *container) receiveBatch(anal process.Analyze) error {
   131  	for {
   132  		bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   133  		if err != nil {
   134  			return err
   135  		}
   136  		if bat != nil {
   137  			ctr.batchRowCount += bat.RowCount()
   138  			ctr.batches = append(ctr.batches, bat)
   139  		} else {
   140  			break
   141  		}
   142  	}
   143  	for i := 0; i < len(ctr.batches)-1; i++ {
   144  		if ctr.batches[i].RowCount() != colexec.DefaultBatchSize {
   145  			panic("wrong batch received for hash build!")
   146  		}
   147  	}
   148  	return nil
   149  }
   150  
   151  func (ctr *container) build(anal process.Analyze) error {
   152  	err := ctr.receiveHashMap(anal)
   153  	if err != nil {
   154  		return err
   155  	}
   156  	return ctr.receiveBatch(anal)
   157  }
   158  
   159  func (ctr *container) emptyProbe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   160  	anal.Input(bat, isFirst)
   161  	if ctr.rbat != nil {
   162  		proc.PutBatch(ctr.rbat)
   163  		ctr.rbat = nil
   164  	}
   165  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   166  	for i, rp := range ap.Result {
   167  		if rp.Rel == 0 {
   168  			ctr.rbat.Vecs[i] = bat.Vecs[rp.Pos]
   169  			bat.Vecs[rp.Pos] = nil
   170  		} else {
   171  			ctr.rbat.Vecs[i] = vector.NewConstNull(ap.Typs[rp.Pos], bat.RowCount(), proc.Mp())
   172  		}
   173  	}
   174  	ctr.rbat.AddRowCount(bat.RowCount())
   175  	anal.Output(ctr.rbat, isLast)
   176  	result.Batch = ctr.rbat
   177  	return nil
   178  }
   179  
   180  func (ctr *container) probe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   181  	anal.Input(bat, isFirst)
   182  	if ctr.rbat != nil {
   183  		proc.PutBatch(ctr.rbat)
   184  		ctr.rbat = nil
   185  	}
   186  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   187  	for i, rp := range ap.Result {
   188  		if rp.Rel != 0 {
   189  			ctr.rbat.Vecs[i] = proc.GetVector(ap.Typs[rp.Pos])
   190  		}
   191  	}
   192  
   193  	if err := ctr.evalJoinCondition(bat, proc); err != nil {
   194  		return err
   195  	}
   196  
   197  	if ctr.joinBat1 == nil {
   198  		ctr.joinBat1, ctr.cfs1 = colexec.NewJoinBatch(bat, proc.Mp())
   199  	}
   200  	if ctr.joinBat2 == nil {
   201  		ctr.joinBat2, ctr.cfs2 = colexec.NewJoinBatch(ctr.batches[0], proc.Mp())
   202  	}
   203  
   204  	count := bat.RowCount()
   205  	mSels := ctr.mp.Sels()
   206  	itr := ctr.mp.NewIterator()
   207  	for i := 0; i < count; i += hashmap.UnitLimit {
   208  		n := count - i
   209  		if n > hashmap.UnitLimit {
   210  			n = hashmap.UnitLimit
   211  		}
   212  		copy(ctr.inBuckets, hashmap.OneUInt8s)
   213  		vals, zvals := itr.Find(i, n, ctr.vecs, ctr.inBuckets)
   214  		for k := 0; k < n; k++ {
   215  			if ctr.inBuckets[k] == 0 {
   216  				continue
   217  			}
   218  			if zvals[k] == 0 || vals[k] == 0 {
   219  				for j, rp := range ap.Result {
   220  					if rp.Rel != 0 {
   221  						if err := ctr.rbat.Vecs[j].UnionNull(proc.Mp()); err != nil {
   222  							return err
   223  						}
   224  					}
   225  				}
   226  				continue
   227  			}
   228  			if ap.HashOnPK {
   229  				idx1, idx2 := int64(vals[k]-1)/colexec.DefaultBatchSize, int64(vals[k]-1)%colexec.DefaultBatchSize
   230  				matched := false
   231  				if ap.Cond != nil {
   232  					if err := colexec.SetJoinBatchValues(ctr.joinBat1, bat, int64(i+k),
   233  						1, ctr.cfs1); err != nil {
   234  						return err
   235  					}
   236  					if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], idx2,
   237  						1, ctr.cfs2); err != nil {
   238  						return err
   239  					}
   240  					vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2})
   241  					if err != nil {
   242  						return err
   243  					}
   244  					if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   245  						continue
   246  					}
   247  					bs := vector.MustFixedCol[bool](vec)
   248  					if bs[0] {
   249  						if matched {
   250  							return moerr.NewInternalError(proc.Ctx, "scalar subquery returns more than 1 row")
   251  						}
   252  						matched = true
   253  					}
   254  				}
   255  				if ap.Cond != nil && !matched {
   256  					for j, rp := range ap.Result {
   257  						if rp.Rel != 0 {
   258  							if err := ctr.rbat.Vecs[j].UnionNull(proc.Mp()); err != nil {
   259  								return err
   260  							}
   261  						}
   262  					}
   263  					continue
   264  				}
   265  				for j, rp := range ap.Result {
   266  					if rp.Rel != 0 {
   267  						if err := ctr.rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], idx2, proc.Mp()); err != nil {
   268  							return err
   269  						}
   270  					}
   271  				}
   272  			} else {
   273  				idx := 0
   274  				matched := false
   275  				sels := mSels[vals[k]-1]
   276  				if ap.Cond != nil {
   277  					for j, sel := range sels {
   278  						idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   279  						if err := colexec.SetJoinBatchValues(ctr.joinBat1, bat, int64(i+k),
   280  							1, ctr.cfs1); err != nil {
   281  							return err
   282  						}
   283  						if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], int64(idx2),
   284  							1, ctr.cfs2); err != nil {
   285  							return err
   286  						}
   287  						vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2})
   288  						if err != nil {
   289  							return err
   290  						}
   291  						if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   292  							continue
   293  						}
   294  						bs := vector.MustFixedCol[bool](vec)
   295  						if bs[0] {
   296  							if matched {
   297  								return moerr.NewInternalError(proc.Ctx, "scalar subquery returns more than 1 row")
   298  							}
   299  							matched = true
   300  							idx = j
   301  						}
   302  					}
   303  				} else if len(sels) > 1 {
   304  					return moerr.NewInternalError(proc.Ctx, "scalar subquery returns more than 1 row")
   305  				}
   306  				if ap.Cond != nil && !matched {
   307  					for j, rp := range ap.Result {
   308  						if rp.Rel != 0 {
   309  							if err := ctr.rbat.Vecs[j].UnionNull(proc.Mp()); err != nil {
   310  								return err
   311  							}
   312  						}
   313  					}
   314  					continue
   315  				}
   316  				sel := sels[idx]
   317  				for j, rp := range ap.Result {
   318  					if rp.Rel != 0 {
   319  						idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   320  						if err := ctr.rbat.Vecs[j].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], int64(idx2), proc.Mp()); err != nil {
   321  							return err
   322  						}
   323  					}
   324  				}
   325  			}
   326  		}
   327  	}
   328  	for i, rp := range ap.Result {
   329  		if rp.Rel == 0 {
   330  			// rbat.Vecs[i] = bat.Vecs[rp.Pos]
   331  			// bat.Vecs[rp.Pos] = nil
   332  			typ := *bat.Vecs[rp.Pos].GetType()
   333  			ctr.rbat.Vecs[i] = proc.GetVector(typ)
   334  			if err := vector.GetUnionAllFunction(typ, proc.Mp())(ctr.rbat.Vecs[i], bat.Vecs[rp.Pos]); err != nil {
   335  				return err
   336  			}
   337  		}
   338  	}
   339  	ctr.rbat.SetRowCount(ctr.rbat.RowCount() + bat.RowCount())
   340  	anal.Output(ctr.rbat, isLast)
   341  	result.Batch = ctr.rbat
   342  	return nil
   343  }
   344  
   345  func (ctr *container) evalJoinCondition(bat *batch.Batch, proc *process.Process) error {
   346  	for i := range ctr.evecs {
   347  		vec, err := ctr.evecs[i].executor.Eval(proc, []*batch.Batch{bat})
   348  		if err != nil {
   349  			return err
   350  		}
   351  		ctr.vecs[i] = vec
   352  		ctr.evecs[i].vec = vec
   353  	}
   354  	return nil
   355  }