github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/colexec/join/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package join
    16  
    17  import (
    18  	"bytes"
    19  	"time"
    20  
    21  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    22  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    23  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    24  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    25  	"github.com/matrixorigin/matrixone/pkg/sql/plan"
    26  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    27  )
    28  
    29  func String(_ any, buf *bytes.Buffer) {
    30  	buf.WriteString(" inner join ")
    31  }
    32  
    33  func Prepare(proc *process.Process, arg any) error {
    34  	ap := arg.(*Argument)
    35  	ap.ctr = new(container)
    36  	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    37  	ap.ctr.evecs = make([]evalVector, len(ap.Conditions[0]))
    38  	ap.ctr.vecs = make([]*vector.Vector, len(ap.Conditions[0]))
    39  	return nil
    40  }
    41  
    42  func Call(idx int, proc *process.Process, arg any, isFirst bool, isLast bool) (bool, error) {
    43  	anal := proc.GetAnalyze(idx)
    44  	anal.Start()
    45  	defer anal.Stop()
    46  	ap := arg.(*Argument)
    47  	ctr := ap.ctr
    48  	for {
    49  		switch ctr.state {
    50  		case Build:
    51  			if err := ctr.build(ap, proc, anal); err != nil {
    52  				ap.Free(proc, true)
    53  				return false, err
    54  			}
    55  			ctr.state = Probe
    56  		case Probe:
    57  			start := time.Now()
    58  			bat := <-proc.Reg.MergeReceivers[0].Ch
    59  			anal.WaitStop(start)
    60  
    61  			if bat == nil {
    62  				ctr.state = End
    63  				continue
    64  			}
    65  			if bat.Length() == 0 {
    66  				continue
    67  			}
    68  			if ctr.bat == nil || ctr.bat.Length() == 0 {
    69  				bat.Clean(proc.Mp())
    70  				continue
    71  			}
    72  			if err := ctr.probe(bat, ap, proc, anal, isFirst, isLast); err != nil {
    73  				ap.Free(proc, true)
    74  				return false, err
    75  			}
    76  			return false, nil
    77  
    78  		default:
    79  			ap.Free(proc, false)
    80  			proc.SetInputBatch(nil)
    81  			return true, nil
    82  		}
    83  	}
    84  }
    85  
    86  func (ctr *container) build(ap *Argument, proc *process.Process, anal process.Analyze) error {
    87  	start := time.Now()
    88  	bat := <-proc.Reg.MergeReceivers[1].Ch
    89  	anal.WaitStop(start)
    90  
    91  	if bat != nil {
    92  		ctr.bat = bat
    93  		ctr.mp = bat.Ht.(*hashmap.JoinMap).Dup()
    94  		anal.Alloc(ctr.mp.Map().Size())
    95  	}
    96  	return nil
    97  }
    98  
    99  func (ctr *container) probe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool) error {
   100  	defer bat.Clean(proc.Mp())
   101  	anal.Input(bat, isFirst)
   102  	rbat := batch.NewWithSize(len(ap.Result))
   103  	rbat.Zs = proc.Mp().GetSels()
   104  	for i, rp := range ap.Result {
   105  		if rp.Rel == 0 {
   106  			rbat.Vecs[i] = vector.New(bat.Vecs[rp.Pos].Typ)
   107  		} else {
   108  			rbat.Vecs[i] = vector.New(ctr.bat.Vecs[rp.Pos].Typ)
   109  		}
   110  	}
   111  
   112  	idxFlg := false
   113  	ctr.cleanEvalVectors(proc.Mp())
   114  	if err := ctr.evalJoinCondition(bat, ap.Conditions[0], proc, &idxFlg, anal); err != nil {
   115  		rbat.Clean(proc.Mp())
   116  		return err
   117  	}
   118  
   119  	mSels := ctr.mp.Sels()
   120  
   121  	count := bat.Length()
   122  	itr := ctr.mp.Map().NewIterator()
   123  	for i := 0; i < count; i += hashmap.UnitLimit {
   124  		n := count - i
   125  		if n > hashmap.UnitLimit {
   126  			n = hashmap.UnitLimit
   127  		}
   128  		copy(ctr.inBuckets, hashmap.OneUInt8s)
   129  		vals, zvals := itr.Find(i, n, ctr.vecs, ctr.inBuckets)
   130  		for k := 0; k < n; k++ {
   131  			if ctr.inBuckets[k] == 0 || zvals[k] == 0 || vals[k] == 0 {
   132  				continue
   133  			}
   134  			sels := mSels[vals[k]-1]
   135  			if ap.Cond != nil {
   136  				for _, sel := range sels {
   137  					vec, err := colexec.JoinFilterEvalExprInBucket(bat, ctr.bat, i+k, int(sel), proc, ap.Cond)
   138  					if err != nil {
   139  						rbat.Clean(proc.Mp())
   140  						return err
   141  					}
   142  					bs := vec.Col.([]bool)
   143  					if !bs[0] {
   144  						vec.Free(proc.Mp())
   145  						continue
   146  					}
   147  					vec.Free(proc.Mp())
   148  					for j, rp := range ap.Result {
   149  						if rp.Rel == 0 {
   150  							if err := vector.UnionOne(rbat.Vecs[j], bat.Vecs[rp.Pos], int64(i+k), proc.Mp()); err != nil {
   151  								rbat.Clean(proc.Mp())
   152  								return err
   153  							}
   154  						} else {
   155  							if err := vector.UnionOne(rbat.Vecs[j], ctr.bat.Vecs[rp.Pos], int64(sel), proc.Mp()); err != nil {
   156  								rbat.Clean(proc.Mp())
   157  								return err
   158  							}
   159  						}
   160  					}
   161  					rbat.Zs = append(rbat.Zs, ctr.bat.Zs[sel])
   162  				}
   163  			} else {
   164  				for j, rp := range ap.Result {
   165  					if rp.Rel == 0 {
   166  						if err := vector.UnionMulti(rbat.Vecs[j], bat.Vecs[rp.Pos], int64(i+k), len(sels), proc.Mp()); err != nil {
   167  							rbat.Clean(proc.Mp())
   168  							return err
   169  						}
   170  					} else {
   171  						for _, sel := range sels {
   172  							if err := vector.UnionOne(rbat.Vecs[j], ctr.bat.Vecs[rp.Pos], int64(sel), proc.Mp()); err != nil {
   173  								rbat.Clean(proc.Mp())
   174  								return err
   175  							}
   176  						}
   177  					}
   178  				}
   179  				for _, sel := range sels {
   180  					rbat.Zs = append(rbat.Zs, ctr.bat.Zs[sel])
   181  				}
   182  			}
   183  		}
   184  	}
   185  	rbat.ExpandNulls()
   186  	anal.Output(rbat, isLast)
   187  	proc.SetInputBatch(rbat)
   188  	return nil
   189  }
   190  
   191  func (ctr *container) evalJoinCondition(bat *batch.Batch, conds []*plan.Expr, proc *process.Process, flg *bool, analyze process.Analyze) error {
   192  	for i, cond := range conds {
   193  		vec, err := colexec.EvalExpr(bat, proc, cond)
   194  		if err != nil || vec.ConstExpand(false, proc.Mp()) == nil {
   195  			ctr.cleanEvalVectors(proc.Mp())
   196  			return err
   197  		}
   198  		ctr.vecs[i] = vec
   199  		ctr.evecs[i].vec = vec
   200  		ctr.evecs[i].needFree = true
   201  		for j := range bat.Vecs {
   202  			if bat.Vecs[j] == vec {
   203  				ctr.evecs[i].needFree = false
   204  				break
   205  			}
   206  		}
   207  		if ctr.evecs[i].needFree && vec != nil {
   208  			analyze.Alloc(int64(vec.Size()))
   209  		}
   210  	}
   211  	return nil
   212  }