github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/colexec/mark/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mark
    16  
    17  import (
    18  	"bytes"
    19  	"time"
    20  
    21  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    22  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    23  	"github.com/matrixorigin/matrixone/pkg/container/nulls"
    24  	"github.com/matrixorigin/matrixone/pkg/container/types"
    25  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    26  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    27  	"github.com/matrixorigin/matrixone/pkg/sql/plan"
    28  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    29  )
    30  
    31  func String(_ any, buf *bytes.Buffer) {
    32  	buf.WriteString(" mark join ")
    33  }
    34  
    35  func Prepare(proc *process.Process, arg any) error {
    36  	ap := arg.(*Argument)
    37  	ap.ctr = new(container)
    38  	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    39  	ap.ctr.evecs = make([]evalVector, len(ap.Conditions[0]))
    40  	ap.ctr.vecs = make([]*vector.Vector, len(ap.Conditions[0]))
    41  	ap.ctr.bat = batch.NewWithSize(len(ap.Typs))
    42  	ap.ctr.bat.Zs = proc.Mp().GetSels()
    43  	for i, typ := range ap.Typs {
    44  		ap.ctr.bat.Vecs[i] = vector.New(typ)
    45  	}
    46  
    47  	ap.ctr.buildEqVec = make([]*vector.Vector, len(ap.Conditions[1]))
    48  	ap.ctr.buildEqEvecs = make([]evalVector, len(ap.Conditions[1]))
    49  	return nil
    50  }
    51  
// Note: before the mark join runs, the right table has already been consumed by the hashbuild operator to
// build the JoinMap, which contains only the tuples without NULLs; the indices of the tuples that contain
// NULLs are stored in nullSels.

// 1. Each tuple of the left table is joined with the tuple(s) of the right table under three-valued logic.
// The conditions may include both equality and non-equality conditions. Results for the same pair of rows
// are combined with three-valued AND; results across different right rows are combined with three-valued OR.

// 2.1 If a probe tuple contains NULL (i.e. zvals[k] == 0):
//       scan the whole right table directly and join with each tuple to determine the state.

// 2.2 If a probe tuple contains no NULL, first look it up in the JoinMap to check the equality conditions (condEq).
//     2.2.1 If condEq is condTrue in the JoinMap (i.e. vals[k] > 0),
//           further check the non-equality conditions on the matching tuples IN the JoinMap:
//           2.2.1.1 if condNonEq is condTrue, mark as condTrue;
//           2.2.1.2 if condNonEq is condUnkown, mark as condUnkown;
//           2.2.1.3 if condNonEq is condFalse in the JoinMap,
//                   further check the equality and non-equality conditions on the tuples IN nullSels
//                   (the probe state can still be unknown, but not false, as long as one unknown state
//                   exists, so the NULL-containing right tuples must be evaluated as well).
//     2.2.2 If condEq is condFalse in the JoinMap,
//           check the equality and non-equality conditions on the tuples in nullSels to determine
//           condState (same as 2.2.1.3).
    74  
    75  func Call(idx int, proc *process.Process, arg any, isFirst bool, isLast bool) (bool, error) {
    76  	anal := proc.GetAnalyze(idx)
    77  	anal.Start()
    78  	defer anal.Stop()
    79  	ap := arg.(*Argument)
    80  	ctr := ap.ctr
    81  	for {
    82  		switch ctr.state {
    83  		case Build:
    84  			if err := ctr.build(ap, proc, anal); err != nil {
    85  				return false, err
    86  			}
    87  			ctr.state = Probe
    88  
    89  		case Probe:
    90  			start := time.Now()
    91  			bat := <-proc.Reg.MergeReceivers[0].Ch
    92  			anal.WaitStop(start)
    93  
    94  			if bat == nil {
    95  				ctr.state = End
    96  				continue
    97  			}
    98  			if bat.Length() == 0 {
    99  				continue
   100  			}
   101  			if ctr.bat == nil || ctr.bat.Length() == 0 {
   102  				if err := ctr.emptyProbe(bat, ap, proc, anal, isFirst, isLast); err != nil {
   103  					ap.Free(proc, true)
   104  					return true, err
   105  				}
   106  			} else {
   107  				if err := ctr.probe(bat, ap, proc, anal, isFirst, isLast); err != nil {
   108  					ap.Free(proc, true)
   109  					return true, err
   110  				}
   111  			}
   112  			return false, nil
   113  
   114  		default:
   115  			ap.Free(proc, false)
   116  			proc.SetInputBatch(nil)
   117  			return true, nil
   118  		}
   119  	}
   120  }
   121  
   122  func (ctr *container) build(ap *Argument, proc *process.Process, anal process.Analyze) error {
   123  	start := time.Now()
   124  	bat := <-proc.Reg.MergeReceivers[1].Ch
   125  	anal.WaitStop(start)
   126  
   127  	if bat != nil {
   128  		var err error
   129  		joinMap := bat.Ht.(*hashmap.JoinMap)
   130  		ctr.evalNullSels(bat)
   131  		ctr.nullWithBatch, err = DumpBatch(bat, proc, ctr.nullSels)
   132  		if err != nil {
   133  			return err
   134  		}
   135  		if err = ctr.evalJoinBuildCondition(bat, ap.Conditions[1], proc); err != nil {
   136  			return err
   137  		}
   138  		ctr.rewriteCond = colexec.RewriteFilterExprList(ap.OnList)
   139  		ctr.bat = bat
   140  		ctr.mp = joinMap.Dup()
   141  		//ctr.bat = bat
   142  		//ctr.mp = bat.Ht.(*hashmap.JoinMap).Dup()
   143  		//anal.Alloc(ctr.mp.Map().Size())
   144  	}
   145  	return nil
   146  }
   147  
   148  func (ctr *container) emptyProbe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool) error {
   149  	defer bat.Clean(proc.Mp())
   150  	anal.Input(bat, isFirst)
   151  	rbat := batch.NewWithSize(len(ap.Result))
   152  	count := bat.Length()
   153  	for i, rp := range ap.Result {
   154  		if rp >= 0 {
   155  			rbat.Vecs[i] = bat.Vecs[rp]
   156  			bat.Vecs[rp] = nil
   157  		} else {
   158  			rbat.Vecs[i] = vector.NewConstFixed(types.T_bool.ToType(), count, false, proc.Mp())
   159  		}
   160  	}
   161  	rbat.Zs = bat.Zs
   162  	bat.Zs = nil
   163  	anal.Output(rbat, isLast)
   164  	proc.SetInputBatch(rbat)
   165  	return nil
   166  }
   167  
   168  func (ctr *container) probe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool) error {
   169  	defer bat.Clean(proc.Mp())
   170  	anal.Input(bat, isFirst)
   171  	rbat := batch.NewWithSize(len(ap.Result))
   172  	ctr.markVals = make([]bool, bat.Length())
   173  	ctr.markNulls = nulls.NewWithSize(bat.Length())
   174  	ctr.cleanEvalVectors(proc.Mp())
   175  	if err := ctr.evalJoinProbeCondition(bat, ap.Conditions[0], proc, anal); err != nil {
   176  		rbat.Clean(proc.Mp())
   177  		return err
   178  	}
   179  
   180  	count := bat.Length()
   181  	mSels := ctr.mp.Sels()
   182  	itr := ctr.mp.Map().NewIterator()
   183  	for i := 0; i < count; i += hashmap.UnitLimit {
   184  		n := count - i
   185  		if n > hashmap.UnitLimit {
   186  			n = hashmap.UnitLimit
   187  		}
   188  		copy(ctr.inBuckets, hashmap.OneUInt8s)
   189  		vals, zvals := itr.Find(i, n, ctr.vecs, ctr.inBuckets)
   190  		var condState resultType
   191  		// var condNonEq resultType
   192  		// var condEq resultType
   193  		var err error
   194  		for k := 0; k < n; k++ {
   195  			if ctr.inBuckets[k] == 0 {
   196  				continue
   197  			}
   198  			if zvals[k] == 0 { // 2.1 : probe tuple has null
   199  				condState, err = ctr.EvalEntire(bat, ctr.bat, i+k, proc, ctr.rewriteCond)
   200  				if err != nil {
   201  					rbat.Clean(proc.Mp())
   202  					return err
   203  				}
   204  				ctr.handleResultType(i+k, condState)
   205  			} else if vals[k] > 0 { // 2.2.1 : condEq is condTrue in JoinMap
   206  				condState, err = ctr.nonEqJoinInMap(ap, mSels, vals, k, i, proc, bat)
   207  				if err != nil {
   208  					rbat.Clean(proc.Mp())
   209  					return err
   210  				}
   211  				if condState == condTrue { // 2.2.1.1 : condNonEq is condTrue in JoinMap
   212  					ctr.markVals[i+k] = true
   213  				} else if condState == condUnkown { // 2.2.1.2 : condNonEq is condUnkown in JoinMap
   214  					nulls.Add(ctr.markNulls, uint64(i+k))
   215  				} else { // 2.2.1.3 : condNonEq is condFalse in JoinMap, further check in nullSels
   216  					if len(ctr.nullSels) == 0 {
   217  						ctr.handleResultType(i+k, condFalse)
   218  						continue
   219  					}
   220  					condState, err = ctr.EvalEntire(bat, ctr.nullWithBatch, i+k, proc, ctr.rewriteCond)
   221  					if err != nil {
   222  						rbat.Clean(proc.Mp())
   223  						return err
   224  					}
   225  					ctr.handleResultType(i+k, condState)
   226  				}
   227  			} else { // 2.2.2 : condEq in condFalse in JoinMap, further check in nullSels
   228  				if len(ctr.nullSels) == 0 {
   229  					ctr.handleResultType(i+k, condFalse)
   230  					continue
   231  				}
   232  				condState, err = ctr.EvalEntire(bat, ctr.nullWithBatch, i+k, proc, ctr.rewriteCond)
   233  				if err != nil {
   234  					rbat.Clean(proc.Mp())
   235  					return err
   236  				}
   237  				ctr.handleResultType(i+k, condState)
   238  			}
   239  		}
   240  	}
   241  	for i, pos := range ap.Result {
   242  		if pos >= 0 {
   243  			rbat.Vecs[i] = bat.Vecs[pos]
   244  			bat.Vecs[pos] = nil
   245  		} else {
   246  			rbat.Vecs[i] = vector.NewWithFixed(types.T_bool.ToType(), ctr.markVals, ctr.markNulls, proc.Mp())
   247  		}
   248  	}
   249  	rbat.Zs = bat.Zs
   250  	bat.Zs = nil
   251  	//rbat.ExpandNulls()
   252  	anal.Output(rbat, isLast)
   253  	proc.SetInputBatch(rbat)
   254  	return nil
   255  }
   256  
   257  // store the results of the calculation on the probe side of the equation condition
   258  func (ctr *container) evalJoinProbeCondition(bat *batch.Batch, conds []*plan.Expr, proc *process.Process, analyze process.Analyze) error {
   259  	for i, cond := range conds {
   260  		vec, err := colexec.EvalExpr(bat, proc, cond)
   261  		if err != nil || vec.ConstExpand(false, proc.Mp()) == nil {
   262  			ctr.cleanEvalVectors(proc.Mp())
   263  			return err
   264  		}
   265  		ctr.vecs[i] = vec
   266  		ctr.evecs[i].vec = vec
   267  		ctr.evecs[i].needFree = true
   268  		for j := range bat.Vecs {
   269  			if bat.Vecs[j] == vec {
   270  				ctr.evecs[i].needFree = false
   271  				break
   272  			}
   273  		}
   274  		if ctr.evecs[i].needFree && vec != nil {
   275  			analyze.Alloc(int64(vec.Size()))
   276  		}
   277  	}
   278  	return nil
   279  }
   280  
   281  // store the results of the calculation on the build side of the equation condition
   282  func (ctr *container) evalJoinBuildCondition(bat *batch.Batch, conds []*plan.Expr, proc *process.Process) error {
   283  	for i, cond := range conds {
   284  		vec, err := colexec.EvalExpr(bat, proc, cond)
   285  		if err != nil || vec.ConstExpand(false, proc.Mp()) == nil {
   286  			ctr.cleanEvalVectors(proc.Mp())
   287  			return err
   288  		}
   289  		ctr.buildEqVec[i] = vec
   290  		ctr.buildEqEvecs[i].vec = vec
   291  		ctr.buildEqEvecs[i].needFree = true
   292  		for j := range bat.Vecs {
   293  			if bat.Vecs[j] == vec {
   294  				ctr.buildEqEvecs[i].needFree = false
   295  				break
   296  			}
   297  		}
   298  	}
   299  	return nil
   300  }
   301  
   302  // calculate the state of non-equal conditions for those tuples in JoinMap
   303  func (ctr *container) nonEqJoinInMap(ap *Argument, mSels [][]int32, vals []uint64, k int, i int, proc *process.Process, bat *batch.Batch) (resultType, error) {
   304  	if ap.Cond != nil {
   305  		condState := condFalse
   306  		sels := mSels[vals[k]-1]
   307  		for _, sel := range sels {
   308  			vec, err := colexec.JoinFilterEvalExprInBucket(bat, ctr.bat, i+k, int(sel), proc, ap.Cond)
   309  			if err != nil {
   310  				return condUnkown, err
   311  			}
   312  			if vec.Nsp.Contains(0) {
   313  				condState = condUnkown
   314  			}
   315  			bs := vec.Col.([]bool)
   316  			if bs[0] {
   317  				condState = condTrue
   318  				vec.Free(proc.Mp())
   319  				break
   320  			}
   321  			vec.Free(proc.Mp())
   322  		}
   323  		return condState, nil
   324  	} else {
   325  		return condTrue, nil
   326  	}
   327  }
   328  
   329  func (ctr *container) EvalEntire(pbat, bat *batch.Batch, idx int, proc *process.Process, cond *plan.Expr) (resultType, error) {
   330  	if cond == nil {
   331  		return condTrue, nil
   332  	}
   333  	vec, err := colexec.JoinFilterEvalExpr(pbat, bat, idx, proc, cond)
   334  	defer vec.Free(proc.Mp())
   335  	if err != nil {
   336  		return condUnkown, err
   337  	}
   338  	bs := vec.Col.([]bool)
   339  	for _, b := range bs {
   340  		if b {
   341  			return condTrue, nil
   342  		}
   343  	}
   344  	if nulls.Any(vec.Nsp) {
   345  		return condUnkown, nil
   346  	}
   347  	return condFalse, nil
   348  }
   349  
   350  // collect the idx of tuple which contains null values
   351  func (ctr *container) evalNullSels(bat *batch.Batch) {
   352  	joinMap := bat.Ht.(*hashmap.JoinMap)
   353  	jmSels := joinMap.Sels()
   354  	selsMap := make(map[int32]bool)
   355  	for _, sel := range jmSels {
   356  		for _, i := range sel {
   357  			selsMap[i] = true
   358  		}
   359  	}
   360  	var nullSels []int64
   361  	for i := 0; i < bat.Length(); i++ {
   362  		if selsMap[int32(i)] {
   363  			ctr.sels = append(ctr.sels, int64(i))
   364  			continue
   365  		}
   366  		nullSels = append(nullSels, int64(i))
   367  	}
   368  	ctr.nullSels = nullSels
   369  }
   370  
   371  // mark probe tuple state
   372  func (ctr *container) handleResultType(idx int, r resultType) {
   373  	switch r {
   374  	case condTrue:
   375  		ctr.markVals[idx] = true
   376  	case condFalse:
   377  		ctr.markVals[idx] = false
   378  	case condUnkown:
   379  		nulls.Add(ctr.markNulls, uint64(idx))
   380  	}
   381  }
   382  
   383  func DumpBatch(originBatch *batch.Batch, proc *process.Process, sels []int64) (*batch.Batch, error) {
   384  	length := originBatch.Length()
   385  	flags := make([]uint8, length)
   386  	for _, sel := range sels {
   387  		flags[sel] = 1
   388  	}
   389  	bat := batch.NewWithSize(len(originBatch.Vecs))
   390  	for i, vec := range originBatch.Vecs {
   391  		bat.Vecs[i] = vector.New(vec.GetType())
   392  	}
   393  	if len(sels) == 0 {
   394  		return bat, nil
   395  	}
   396  	for i, vec := range originBatch.Vecs {
   397  		err := vector.UnionBatch(bat.Vecs[i], vec, 0, length, flags, proc.Mp())
   398  		if err != nil {
   399  			return nil, err
   400  		}
   401  	}
   402  
   403  	bat.Zs = append(bat.Zs, originBatch.Zs...)
   404  	return bat, nil
   405  }