github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/colexec/semi/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package semi
    16  
import (
	"bytes"
	"errors"
	"time"

	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
	"github.com/matrixorigin/matrixone/pkg/container/batch"
	"github.com/matrixorigin/matrixone/pkg/container/vector"
	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
	"github.com/matrixorigin/matrixone/pkg/sql/plan"
	"github.com/matrixorigin/matrixone/pkg/vm/process"
)
    28  
    29  func String(_ any, buf *bytes.Buffer) {
    30  	buf.WriteString(" semi join ")
    31  }
    32  
    33  func Prepare(proc *process.Process, arg any) error {
    34  	ap := arg.(*Argument)
    35  	ap.ctr = new(container)
    36  	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    37  	ap.ctr.evecs = make([]evalVector, len(ap.Conditions[0]))
    38  	ap.ctr.vecs = make([]*vector.Vector, len(ap.Conditions[0]))
    39  	return nil
    40  }
    41  
    42  func Call(idx int, proc *process.Process, arg any, isFirst bool, isLast bool) (bool, error) {
    43  	anal := proc.GetAnalyze(idx)
    44  	anal.Start()
    45  	defer anal.Stop()
    46  	ap := arg.(*Argument)
    47  	ctr := ap.ctr
    48  	for {
    49  		switch ctr.state {
    50  		case Build:
    51  			if err := ctr.build(ap, proc, anal); err != nil {
    52  				ap.Free(proc, true)
    53  				return false, err
    54  			}
    55  			ctr.state = Probe
    56  
    57  		case Probe:
    58  			start := time.Now()
    59  			bat := <-proc.Reg.MergeReceivers[0].Ch
    60  			anal.WaitStop(start)
    61  
    62  			if bat == nil {
    63  				ctr.state = End
    64  				continue
    65  			}
    66  			if bat.Length() == 0 {
    67  				continue
    68  			}
    69  			if ctr.bat == nil || ctr.bat.Length() == 0 {
    70  				bat.Clean(proc.Mp())
    71  				continue
    72  			}
    73  			if err := ctr.probe(bat, ap, proc, anal, isFirst, isLast); err != nil {
    74  				ap.Free(proc, true)
    75  				return false, err
    76  			}
    77  			return false, nil
    78  
    79  		default:
    80  			ap.Free(proc, false)
    81  			proc.SetInputBatch(nil)
    82  			return true, nil
    83  		}
    84  	}
    85  }
    86  
    87  func (ctr *container) build(ap *Argument, proc *process.Process, anal process.Analyze) error {
    88  	start := time.Now()
    89  	bat := <-proc.Reg.MergeReceivers[1].Ch
    90  	anal.WaitStop(start)
    91  
    92  	if bat != nil {
    93  		ctr.bat = bat
    94  		ctr.mp = bat.Ht.(*hashmap.JoinMap).Dup()
    95  		anal.Alloc(ctr.mp.Map().Size())
    96  	}
    97  	return nil
    98  }
    99  
   100  func (ctr *container) probe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool) error {
   101  	defer bat.Clean(proc.Mp())
   102  	anal.Input(bat, isFirst)
   103  	rbat := batch.NewWithSize(len(ap.Result))
   104  	rbat.Zs = proc.Mp().GetSels()
   105  	for i, pos := range ap.Result {
   106  		rbat.Vecs[i] = vector.New(bat.Vecs[pos].Typ)
   107  	}
   108  	ctr.cleanEvalVectors(proc.Mp())
   109  	if err := ctr.evalJoinCondition(bat, ap.Conditions[0], proc, anal); err != nil {
   110  		rbat.Clean(proc.Mp())
   111  		return err
   112  	}
   113  	count := bat.Length()
   114  	mSels := ctr.mp.Sels()
   115  	itr := ctr.mp.Map().NewIterator()
   116  	eligible := make([]int64, 0, hashmap.UnitLimit)
   117  	for i := 0; i < count; i += hashmap.UnitLimit {
   118  		n := count - i
   119  		if n > hashmap.UnitLimit {
   120  			n = hashmap.UnitLimit
   121  		}
   122  		copy(ctr.inBuckets, hashmap.OneUInt8s)
   123  		vals, zvals := itr.Find(i, n, ctr.vecs, ctr.inBuckets)
   124  		for k := 0; k < n; k++ {
   125  			if ctr.inBuckets[k] == 0 || zvals[k] == 0 || vals[k] == 0 {
   126  				continue
   127  			}
   128  			if ap.Cond != nil {
   129  				matched := false // mark if any tuple satisfies the condition
   130  				sels := mSels[vals[k]-1]
   131  				for _, sel := range sels {
   132  					vec, err := colexec.JoinFilterEvalExprInBucket(bat, ctr.bat, i+k, int(sel), proc, ap.Cond)
   133  					if err != nil {
   134  						rbat.Clean(proc.Mp())
   135  						return err
   136  					}
   137  					bs := vec.Col.([]bool)
   138  					if bs[0] {
   139  						matched = true
   140  						vec.Free(proc.Mp())
   141  						break
   142  					}
   143  					vec.Free(proc.Mp())
   144  				}
   145  				if !matched {
   146  					continue
   147  				}
   148  			}
   149  			eligible = append(eligible, int64(i+k))
   150  			rbat.Zs = append(rbat.Zs, bat.Zs[i+k])
   151  		}
   152  		for j, pos := range ap.Result {
   153  			if err := vector.Union(rbat.Vecs[j], bat.Vecs[pos], eligible, true, proc.Mp()); err != nil {
   154  				rbat.Clean(proc.Mp())
   155  				return err
   156  			}
   157  		}
   158  		eligible = eligible[:0]
   159  	}
   160  	rbat.ExpandNulls()
   161  	anal.Output(rbat, isLast)
   162  	proc.SetInputBatch(rbat)
   163  	return nil
   164  }
   165  
   166  func (ctr *container) evalJoinCondition(bat *batch.Batch, conds []*plan.Expr, proc *process.Process, analyze process.Analyze) error {
   167  	for i, cond := range conds {
   168  		vec, err := colexec.EvalExpr(bat, proc, cond)
   169  		if err != nil || vec.ConstExpand(false, proc.Mp()) == nil {
   170  			ctr.cleanEvalVectors(proc.Mp())
   171  			return err
   172  		}
   173  		ctr.vecs[i] = vec
   174  		ctr.evecs[i].vec = vec
   175  		ctr.evecs[i].needFree = true
   176  		for j := range bat.Vecs {
   177  			if bat.Vecs[j] == vec {
   178  				ctr.evecs[i].needFree = false
   179  				break
   180  			}
   181  		}
   182  		if ctr.evecs[i].needFree && vec != nil {
   183  			analyze.Alloc(int64(vec.Size()))
   184  		}
   185  	}
   186  	return nil
   187  }