github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/colexec/single/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package single
    16  
    17  import (
    18  	"bytes"
    19  	"time"
    20  
    21  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    22  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    23  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    24  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    25  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    26  	"github.com/matrixorigin/matrixone/pkg/sql/plan"
    27  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    28  )
    29  
    30  func String(_ any, buf *bytes.Buffer) {
    31  	buf.WriteString(" single join ")
    32  }
    33  
    34  func Prepare(proc *process.Process, arg any) error {
    35  	ap := arg.(*Argument)
    36  	ap.ctr = new(container)
    37  	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    38  	ap.ctr.evecs = make([]evalVector, len(ap.Conditions[0]))
    39  	ap.ctr.vecs = make([]*vector.Vector, len(ap.Conditions[0]))
    40  	ap.ctr.bat = batch.NewWithSize(len(ap.Typs))
    41  	ap.ctr.bat.Zs = proc.Mp().GetSels()
    42  	for i, typ := range ap.Typs {
    43  		ap.ctr.bat.Vecs[i] = vector.New(typ)
    44  	}
    45  	return nil
    46  }
    47  
    48  func Call(idx int, proc *process.Process, arg any, isFirst bool, isLast bool) (bool, error) {
    49  	anal := proc.GetAnalyze(idx)
    50  	anal.Start()
    51  	defer anal.Stop()
    52  	ap := arg.(*Argument)
    53  	ctr := ap.ctr
    54  	for {
    55  		switch ctr.state {
    56  		case Build:
    57  			if err := ctr.build(ap, proc, anal); err != nil {
    58  				ap.Free(proc, true)
    59  				return false, err
    60  			}
    61  			ctr.state = Probe
    62  
    63  		case Probe:
    64  			start := time.Now()
    65  			bat := <-proc.Reg.MergeReceivers[0].Ch
    66  			anal.WaitStop(start)
    67  
    68  			if bat == nil {
    69  				ctr.state = End
    70  				continue
    71  			}
    72  			if bat.Length() == 0 {
    73  				continue
    74  			}
    75  			if ctr.bat.Length() == 0 {
    76  				if err := ctr.emptyProbe(bat, ap, proc, anal, isFirst, isLast); err != nil {
    77  					ap.Free(proc, true)
    78  					return true, err
    79  				}
    80  			} else {
    81  				if err := ctr.probe(bat, ap, proc, anal, isFirst, isLast); err != nil {
    82  					ap.Free(proc, true)
    83  					return true, err
    84  				}
    85  			}
    86  			return false, nil
    87  
    88  		default:
    89  			ap.Free(proc, false)
    90  			proc.SetInputBatch(nil)
    91  			return true, nil
    92  		}
    93  	}
    94  }
    95  
    96  func (ctr *container) build(ap *Argument, proc *process.Process, anal process.Analyze) error {
    97  	start := time.Now()
    98  	bat := <-proc.Reg.MergeReceivers[1].Ch
    99  	anal.WaitStop(start)
   100  
   101  	if bat != nil {
   102  		ctr.bat = bat
   103  		ctr.mp = bat.Ht.(*hashmap.JoinMap).Dup()
   104  		anal.Alloc(ctr.mp.Map().Size())
   105  	}
   106  	return nil
   107  }
   108  
   109  func (ctr *container) emptyProbe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool) error {
   110  	defer bat.Clean(proc.Mp())
   111  	anal.Input(bat, isFirst)
   112  	rbat := batch.NewWithSize(len(ap.Result))
   113  	count := bat.Length()
   114  	for i, rp := range ap.Result {
   115  		if rp.Rel == 0 {
   116  			rbat.Vecs[i] = bat.Vecs[rp.Pos]
   117  			bat.Vecs[rp.Pos] = nil
   118  		} else {
   119  			rbat.Vecs[i] = vector.NewConstNull(ctr.bat.Vecs[rp.Pos].Typ, count)
   120  		}
   121  	}
   122  	rbat.Zs = bat.Zs
   123  	bat.Zs = nil
   124  	anal.Output(rbat, isLast)
   125  	proc.SetInputBatch(rbat)
   126  	return nil
   127  }
   128  
   129  func (ctr *container) probe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool) error {
   130  	defer bat.Clean(proc.Mp())
   131  	anal.Input(bat, isFirst)
   132  	rbat := batch.NewWithSize(len(ap.Result))
   133  	for i, rp := range ap.Result {
   134  		if rp.Rel != 0 {
   135  			rbat.Vecs[i] = vector.New(ctr.bat.Vecs[rp.Pos].Typ)
   136  		}
   137  	}
   138  	ctr.cleanEvalVectors(proc.Mp())
   139  	if err := ctr.evalJoinCondition(bat, ap.Conditions[0], proc, anal); err != nil {
   140  		rbat.Clean(proc.Mp())
   141  		return err
   142  	}
   143  
   144  	count := bat.Length()
   145  	mSels := ctr.mp.Sels()
   146  	itr := ctr.mp.Map().NewIterator()
   147  	for i := 0; i < count; i += hashmap.UnitLimit {
   148  		n := count - i
   149  		if n > hashmap.UnitLimit {
   150  			n = hashmap.UnitLimit
   151  		}
   152  		copy(ctr.inBuckets, hashmap.OneUInt8s)
   153  		vals, zvals := itr.Find(i, n, ctr.vecs, ctr.inBuckets)
   154  		for k := 0; k < n; k++ {
   155  			if ctr.inBuckets[k] == 0 {
   156  				continue
   157  			}
   158  			if zvals[k] == 0 || vals[k] == 0 {
   159  				for j, rp := range ap.Result {
   160  					if rp.Rel != 0 {
   161  						if err := vector.UnionNull(rbat.Vecs[j], ctr.bat.Vecs[rp.Pos], proc.Mp()); err != nil {
   162  							rbat.Clean(proc.Mp())
   163  							return err
   164  						}
   165  					}
   166  				}
   167  				continue
   168  			}
   169  			idx := 0
   170  			matched := false
   171  			sels := mSels[vals[k]-1]
   172  			if ap.Cond != nil {
   173  				for j, sel := range sels {
   174  					vec, err := colexec.JoinFilterEvalExprInBucket(bat, ctr.bat, i+k, int(sel), proc, ap.Cond)
   175  					if err != nil {
   176  						rbat.Clean(proc.Mp())
   177  						return err
   178  					}
   179  					bs := vec.Col.([]bool)
   180  					if bs[0] {
   181  						if matched {
   182  							vec.Free(proc.Mp())
   183  							rbat.Clean(proc.Mp())
   184  							return moerr.NewInternalError(proc.Ctx, "scalar subquery returns more than 1 row")
   185  						}
   186  						matched = true
   187  						idx = j
   188  					}
   189  					vec.Free(proc.Mp())
   190  				}
   191  			} else if len(sels) > 1 {
   192  				rbat.Clean(proc.Mp())
   193  				return moerr.NewInternalError(proc.Ctx, "scalar subquery returns more than 1 row")
   194  			}
   195  			if ap.Cond != nil && !matched {
   196  				for j, rp := range ap.Result {
   197  					if rp.Rel != 0 {
   198  						if err := vector.UnionNull(rbat.Vecs[j], ctr.bat.Vecs[rp.Pos], proc.Mp()); err != nil {
   199  							rbat.Clean(proc.Mp())
   200  							return err
   201  						}
   202  					}
   203  				}
   204  				continue
   205  			}
   206  			sel := sels[idx]
   207  			for j, rp := range ap.Result {
   208  				if rp.Rel != 0 {
   209  					if err := vector.UnionOne(rbat.Vecs[j], ctr.bat.Vecs[rp.Pos], int64(sel), proc.Mp()); err != nil {
   210  						rbat.Clean(proc.Mp())
   211  						return err
   212  					}
   213  				}
   214  			}
   215  		}
   216  	}
   217  	for i, rp := range ap.Result {
   218  		if rp.Rel == 0 {
   219  			rbat.Vecs[i] = bat.Vecs[rp.Pos]
   220  			bat.Vecs[rp.Pos] = nil
   221  		}
   222  	}
   223  	rbat.Zs = bat.Zs
   224  	bat.Zs = nil
   225  	rbat.ExpandNulls()
   226  	anal.Output(rbat, isLast)
   227  	proc.SetInputBatch(rbat)
   228  	return nil
   229  }
   230  
   231  func (ctr *container) evalJoinCondition(bat *batch.Batch, conds []*plan.Expr, proc *process.Process, analyze process.Analyze) error {
   232  	for i, cond := range conds {
   233  		vec, err := colexec.EvalExpr(bat, proc, cond)
   234  		if err != nil || vec.ConstExpand(false, proc.Mp()) == nil {
   235  			ctr.cleanEvalVectors(proc.Mp())
   236  			return err
   237  		}
   238  		ctr.vecs[i] = vec
   239  		ctr.evecs[i].vec = vec
   240  		ctr.evecs[i].needFree = true
   241  		for j := range bat.Vecs {
   242  			if bat.Vecs[j] == vec {
   243  				ctr.evecs[i].needFree = false
   244  				break
   245  			}
   246  		}
   247  		if ctr.evecs[i].needFree && vec != nil {
   248  			analyze.Alloc(int64(vec.Size()))
   249  		}
   250  	}
   251  	return nil
   252  }