github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/mark/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mark
    16  
    17  import (
    18  	"bytes"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    21  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    22  	"github.com/matrixorigin/matrixone/pkg/container/nulls"
    23  	"github.com/matrixorigin/matrixone/pkg/container/types"
    24  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    25  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    26  	"github.com/matrixorigin/matrixone/pkg/sql/plan"
    27  	"github.com/matrixorigin/matrixone/pkg/vm"
    28  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    29  )
    30  
    31  const argName = "mark_join"
    32  
    33  func (arg *Argument) String(buf *bytes.Buffer) {
    34  	buf.WriteString(argName)
    35  	buf.WriteString(": mark join ")
    36  }
    37  
    38  func (arg *Argument) Prepare(proc *process.Process) error {
    39  	var err error
    40  	ap := arg
    41  	ap.ctr = new(container)
    42  	ap.ctr.InitReceiver(proc, false)
    43  	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    44  	ap.ctr.evecs = make([]evalVector, len(ap.Conditions[0]))
    45  	ap.ctr.vecs = make([]*vector.Vector, len(ap.Conditions[0]))
    46  	ap.ctr.bat = batch.NewWithSize(len(ap.Typs))
    47  	for i, typ := range ap.Typs {
    48  		ap.ctr.bat.Vecs[i] = proc.GetVector(typ)
    49  	}
    50  
    51  	ap.ctr.buildEqVec = make([]*vector.Vector, len(ap.Conditions[1]))
    52  	ap.ctr.buildEqEvecs = make([]evalVector, len(ap.Conditions[1]))
    53  
    54  	if ap.Cond != nil {
    55  		ap.ctr.expr, err = colexec.NewExpressionExecutor(proc, ap.Cond)
    56  	}
    57  	return err
    58  }
    59  
// Note: before the mark join runs, the hashbuild operator has already used the
// right table to build the JoinMap, which contains only tuples without nulls.
// The indexes of the tuples that contain nulls are stored in nullSels.

// 1. Each tuple of the left table is joined with the tuple(s) of the right
// table under three-valued logic. The conditions may include both equality and
// non-equality conditions. Results for the same row are combined with
// three-valued AND; results across different rows are combined with
// three-valued OR.

// 2.1 If the probe tuple has a null (i.e. zvals[k] == 0):
//       scan the whole right table directly and join with each tuple to
//       determine the state.

// 2.2 If the probe tuple has no null, first look it up in the JoinMap to check
//     the equality conditions (condEq).
//     2.2.1 If condEq is condTrue in the JoinMap (i.e. vals[k] > 0), further
//           check the non-equality conditions against the matching tuples IN
//           the JoinMap:
//           2.2.1.1 if condNonEq is condTrue, mark the row as condTrue;
//           2.2.1.2 if condNonEq is condUnkown, mark the row as condUnkown;
//           2.2.1.3 if condNonEq is condFalse in the JoinMap, further check the
//                   equality and non-equality conditions against the tuples IN
//                   nullSels. (The probe state can still be unknown, but NOT
//                   false, as long as a single unknown state exists, so the
//                   whole right table has to be scanned.)
//     2.2.2 If condEq is condFalse in the JoinMap, check the equality and
//           non-equality conditions against nullSels to determine the state
//           (same as 2.2.1.3).
    82  
    83  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    84  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    85  		return vm.CancelResult, err
    86  	}
    87  
    88  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    89  	anal.Start()
    90  	defer anal.Stop()
    91  	ap := arg
    92  	ctr := ap.ctr
    93  	result := vm.NewCallResult()
    94  	for {
    95  		switch ctr.state {
    96  		case Build:
    97  			if err := ctr.build(ap, proc, anal); err != nil {
    98  				return result, err
    99  			}
   100  			ctr.state = Probe
   101  
   102  		case Probe:
   103  			bat, _, err := ctr.ReceiveFromSingleReg(0, anal)
   104  			if err != nil {
   105  				return result, err
   106  			}
   107  
   108  			if bat == nil {
   109  				ctr.state = End
   110  				continue
   111  			}
   112  			if bat.IsEmpty() {
   113  				proc.PutBatch(bat)
   114  				continue
   115  			}
   116  			if ctr.bat == nil || ctr.bat.RowCount() == 0 {
   117  				if err = ctr.emptyProbe(bat, ap, proc, anal, ap.GetIsFirst(), ap.GetIsLast(), &result); err != nil {
   118  					bat.Clean(proc.Mp())
   119  					result.Status = vm.ExecStop
   120  					return result, err
   121  				}
   122  			} else {
   123  				if err = ctr.probe(bat, ap, proc, anal, ap.GetIsFirst(), ap.GetIsLast(), &result); err != nil {
   124  					bat.Clean(proc.Mp())
   125  					result.Status = vm.ExecStop
   126  					return result, err
   127  				}
   128  			}
   129  			proc.PutBatch(bat)
   130  			return result, nil
   131  
   132  		default:
   133  			result.Batch = nil
   134  			result.Status = vm.ExecStop
   135  			return result, nil
   136  		}
   137  	}
   138  }
   139  
   140  func (ctr *container) receiveHashMap(anal process.Analyze) error {
   141  	bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   142  	if err != nil {
   143  		return err
   144  	}
   145  	if bat != nil && bat.AuxData != nil {
   146  		ctr.mp = bat.DupJmAuxData()
   147  		ctr.maxAllocSize = max(ctr.maxAllocSize, ctr.mp.Size())
   148  	}
   149  	return nil
   150  }
   151  
   152  func (ctr *container) receiveBatch(ap *Argument, proc *process.Process, anal process.Analyze) error {
   153  	bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   154  	if err != nil {
   155  		return err
   156  	}
   157  	if bat != nil {
   158  		ctr.evalNullSels(bat)
   159  		ctr.nullWithBatch, err = DumpBatch(bat, proc, ctr.nullSels)
   160  		if err != nil {
   161  			return err
   162  		}
   163  		if err = ctr.evalJoinBuildCondition(bat, proc); err != nil {
   164  			return err
   165  		}
   166  		ctr.rewriteCond = colexec.RewriteFilterExprList(ap.OnList)
   167  		if ctr.bat != nil {
   168  			proc.PutBatch(ctr.bat)
   169  			ctr.bat = nil
   170  		}
   171  		ctr.bat = bat
   172  	}
   173  	return nil
   174  }
   175  
   176  func (ctr *container) build(ap *Argument, proc *process.Process, anal process.Analyze) error {
   177  	err := ctr.receiveHashMap(anal)
   178  	if err != nil {
   179  		return err
   180  	}
   181  	return ctr.receiveBatch(ap, proc, anal)
   182  }
   183  
   184  func (ctr *container) emptyProbe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) (err error) {
   185  	anal.Input(bat, isFirst)
   186  	if ctr.rbat != nil {
   187  		proc.PutBatch(ctr.rbat)
   188  		ctr.rbat = nil
   189  	}
   190  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   191  	count := bat.RowCount()
   192  	for i, rp := range ap.Result {
   193  		if rp >= 0 {
   194  			typ := *bat.Vecs[rp].GetType()
   195  			ctr.rbat.Vecs[i] = proc.GetVector(typ)
   196  			err = vector.GetUnionAllFunction(typ, proc.Mp())(ctr.rbat.Vecs[i], bat.Vecs[rp])
   197  		} else {
   198  			ctr.rbat.Vecs[i], err = vector.NewConstFixed(types.T_bool.ToType(), false, count, proc.Mp())
   199  		}
   200  		if err != nil {
   201  			return err
   202  		}
   203  	}
   204  	ctr.rbat.AddRowCount(bat.RowCount())
   205  	anal.Output(ctr.rbat, isLast)
   206  
   207  	result.Batch = ctr.rbat
   208  	return nil
   209  }
   210  
   211  func (ctr *container) probe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   212  	anal.Input(bat, isFirst)
   213  	if ctr.rbat != nil {
   214  		proc.PutBatch(ctr.rbat)
   215  		ctr.rbat = nil
   216  	}
   217  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   218  	markVec, err := proc.AllocVectorOfRows(types.T_bool.ToType(), bat.RowCount(), nil)
   219  	if err != nil {
   220  		return err
   221  	}
   222  	ctr.markVals = vector.MustFixedCol[bool](markVec)
   223  	ctr.markNulls = nulls.NewWithSize(bat.RowCount())
   224  
   225  	if err = ctr.evalJoinProbeCondition(bat, proc); err != nil {
   226  		return err
   227  	}
   228  
   229  	count := bat.RowCount()
   230  	mSels := ctr.mp.Sels()
   231  	itr := ctr.mp.NewIterator()
   232  	for i := 0; i < count; i += hashmap.UnitLimit {
   233  		n := count - i
   234  		if n > hashmap.UnitLimit {
   235  			n = hashmap.UnitLimit
   236  		}
   237  		copy(ctr.inBuckets, hashmap.OneUInt8s)
   238  		vals, zvals := itr.Find(i, n, ctr.vecs, ctr.inBuckets)
   239  		var condState otyp
   240  		// var condNonEq otyp
   241  		// var condEq otyp
   242  		var err error
   243  		for k := 0; k < n; k++ {
   244  			if ctr.inBuckets[k] == 0 {
   245  				continue
   246  			}
   247  			if zvals[k] == 0 { // 2.1 : probe tuple has null
   248  				condState, err = ctr.EvalEntire(bat, ctr.bat, i+k, proc, ctr.rewriteCond)
   249  				if err != nil {
   250  					return err
   251  				}
   252  				ctr.handleResultType(i+k, condState)
   253  			} else if vals[k] > 0 { // 2.2.1 : condEq is condTrue in JoinMap
   254  				condState, err = ctr.nonEqJoinInMap(ap, mSels, vals, k, i, proc, bat)
   255  				if err != nil {
   256  					return err
   257  				}
   258  				if condState == condTrue { // 2.2.1.1 : condNonEq is condTrue in JoinMap
   259  					ctr.markVals[i+k] = true
   260  				} else if condState == condUnkown { // 2.2.1.2 : condNonEq is condUnkown in JoinMap
   261  					nulls.Add(ctr.markNulls, uint64(i+k))
   262  				} else { // 2.2.1.3 : condNonEq is condFalse in JoinMap, further check in nullSels
   263  					if len(ctr.nullSels) == 0 {
   264  						ctr.handleResultType(i+k, condFalse)
   265  						continue
   266  					}
   267  					condState, err = ctr.EvalEntire(bat, ctr.nullWithBatch, i+k, proc, ctr.rewriteCond)
   268  					if err != nil {
   269  						return err
   270  					}
   271  					ctr.handleResultType(i+k, condState)
   272  				}
   273  			} else { // 2.2.2 : condEq in condFalse in JoinMap, further check in nullSels
   274  				if len(ctr.nullSels) == 0 {
   275  					ctr.handleResultType(i+k, condFalse)
   276  					continue
   277  				}
   278  				condState, err = ctr.EvalEntire(bat, ctr.nullWithBatch, i+k, proc, ctr.rewriteCond)
   279  				if err != nil {
   280  					return err
   281  				}
   282  				ctr.handleResultType(i+k, condState)
   283  			}
   284  		}
   285  	}
   286  	for i, pos := range ap.Result {
   287  		if pos >= 0 {
   288  			ctr.rbat.Vecs[i] = bat.Vecs[pos]
   289  			bat.Vecs[pos] = nil
   290  		} else {
   291  			markVec.SetNulls(ctr.markNulls)
   292  			ctr.rbat.Vecs[i] = markVec
   293  		}
   294  	}
   295  	ctr.rbat.AddRowCount(bat.RowCount())
   296  	anal.Output(ctr.rbat, isLast)
   297  	result.Batch = ctr.rbat
   298  	return nil
   299  }
   300  
   301  // store the results of the calculation on the probe side of the equation condition
   302  func (ctr *container) evalJoinProbeCondition(bat *batch.Batch, proc *process.Process) error {
   303  	for i := range ctr.evecs {
   304  		vec, err := ctr.evecs[i].executor.Eval(proc, []*batch.Batch{bat})
   305  		if err != nil {
   306  			ctr.cleanEvalVectors()
   307  			return err
   308  		}
   309  		ctr.vecs[i] = vec
   310  		ctr.evecs[i].vec = vec
   311  	}
   312  	return nil
   313  }
   314  
   315  // store the results of the calculation on the build side of the equation condition
   316  func (ctr *container) evalJoinBuildCondition(bat *batch.Batch, proc *process.Process) error {
   317  	for i := range ctr.buildEqEvecs {
   318  		vec, err := ctr.buildEqEvecs[i].executor.Eval(proc, []*batch.Batch{bat})
   319  		if err != nil {
   320  			ctr.cleanEvalVectors()
   321  			return err
   322  		}
   323  		ctr.buildEqVec[i] = vec
   324  		ctr.buildEqEvecs[i].vec = vec
   325  	}
   326  	return nil
   327  }
   328  
   329  // calculate the state of non-equal conditions for those tuples in JoinMap
   330  func (ctr *container) nonEqJoinInMap(ap *Argument, mSels [][]int32, vals []uint64, k int, i int, proc *process.Process, bat *batch.Batch) (otyp, error) {
   331  	if ap.Cond != nil {
   332  		condState := condFalse
   333  		if ap.HashOnPK {
   334  			if ctr.joinBat1 == nil {
   335  				ctr.joinBat1, ctr.cfs1 = colexec.NewJoinBatch(bat, proc.Mp())
   336  			}
   337  			if ctr.joinBat2 == nil {
   338  				ctr.joinBat2, ctr.cfs2 = colexec.NewJoinBatch(ctr.bat, proc.Mp())
   339  			}
   340  			if err := colexec.SetJoinBatchValues(ctr.joinBat1, bat, int64(i+k),
   341  				1, ctr.cfs1); err != nil {
   342  				return condUnkown, err
   343  			}
   344  			if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.bat, int64(vals[k]-1),
   345  				1, ctr.cfs2); err != nil {
   346  				return condUnkown, err
   347  			}
   348  			vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2})
   349  			if err != nil {
   350  				return condUnkown, err
   351  			}
   352  			if vec.GetNulls().Contains(0) {
   353  				condState = condUnkown
   354  			}
   355  			bs := vector.MustFixedCol[bool](vec)
   356  			if bs[0] {
   357  				condState = condTrue
   358  			}
   359  		} else {
   360  			sels := mSels[vals[k]-1]
   361  			if ctr.joinBat1 == nil {
   362  				ctr.joinBat1, ctr.cfs1 = colexec.NewJoinBatch(bat, proc.Mp())
   363  			}
   364  			if ctr.joinBat2 == nil {
   365  				ctr.joinBat2, ctr.cfs2 = colexec.NewJoinBatch(ctr.bat, proc.Mp())
   366  			}
   367  			for _, sel := range sels {
   368  				if err := colexec.SetJoinBatchValues(ctr.joinBat1, bat, int64(i+k),
   369  					1, ctr.cfs1); err != nil {
   370  					return condUnkown, err
   371  				}
   372  				if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.bat, int64(sel),
   373  					1, ctr.cfs2); err != nil {
   374  					return condUnkown, err
   375  				}
   376  				vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2})
   377  				if err != nil {
   378  					return condUnkown, err
   379  				}
   380  				if vec.GetNulls().Contains(0) {
   381  					condState = condUnkown
   382  				}
   383  				bs := vector.MustFixedCol[bool](vec)
   384  				if bs[0] {
   385  					condState = condTrue
   386  					break
   387  				}
   388  			}
   389  		}
   390  		return condState, nil
   391  	} else {
   392  		return condTrue, nil
   393  	}
   394  }
   395  
   396  func (ctr *container) EvalEntire(pbat, bat *batch.Batch, idx int, proc *process.Process, cond *plan.Expr) (otyp, error) {
   397  	if cond == nil {
   398  		return condTrue, nil
   399  	}
   400  	if ctr.joinBat == nil {
   401  		ctr.joinBat, ctr.cfs = colexec.NewJoinBatch(pbat, proc.Mp())
   402  	}
   403  	if err := colexec.SetJoinBatchValues(ctr.joinBat, pbat, int64(idx), ctr.bat.RowCount(), ctr.cfs); err != nil {
   404  		return condUnkown, err
   405  	}
   406  	vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat, ctr.bat})
   407  	if err != nil {
   408  		return condUnkown, err
   409  	}
   410  
   411  	bs := vector.GenerateFunctionFixedTypeParameter[bool](vec)
   412  	j := uint64(vec.Length())
   413  	hasNull := false
   414  	for i := uint64(0); i < j; i++ {
   415  		b, null := bs.GetValue(i)
   416  		if null {
   417  			hasNull = true
   418  		} else if b {
   419  			return condTrue, nil
   420  		}
   421  	}
   422  	if hasNull {
   423  		return condUnkown, nil
   424  	}
   425  	return condFalse, nil
   426  }
   427  
   428  // collect the idx of tuple which contains null values
   429  func (ctr *container) evalNullSels(bat *batch.Batch) {
   430  	joinMap := bat.AuxData.(*hashmap.JoinMap)
   431  	jmSels := joinMap.Sels()
   432  	selsMap := make(map[int32]bool)
   433  	for _, sel := range jmSels {
   434  		for _, i := range sel {
   435  			selsMap[i] = true
   436  		}
   437  	}
   438  	var nullSels []int64
   439  	for i := 0; i < bat.RowCount(); i++ {
   440  		if selsMap[int32(i)] {
   441  			ctr.sels = append(ctr.sels, int64(i))
   442  			continue
   443  		}
   444  		nullSels = append(nullSels, int64(i))
   445  	}
   446  	ctr.nullSels = nullSels
   447  }
   448  
   449  // mark probe tuple state
   450  func (ctr *container) handleResultType(idx int, r otyp) {
   451  	switch r {
   452  	case condTrue:
   453  		ctr.markVals[idx] = true
   454  	case condFalse:
   455  		ctr.markVals[idx] = false
   456  	case condUnkown:
   457  		nulls.Add(ctr.markNulls, uint64(idx))
   458  	}
   459  }
   460  
   461  func DumpBatch(originBatch *batch.Batch, proc *process.Process, sels []int64) (*batch.Batch, error) {
   462  	length := originBatch.RowCount()
   463  	flags := make([]uint8, length)
   464  	for _, sel := range sels {
   465  		flags[sel] = 1
   466  	}
   467  	bat := batch.NewWithSize(len(originBatch.Vecs))
   468  	for i, vec := range originBatch.Vecs {
   469  		bat.Vecs[i] = proc.GetVector(*vec.GetType())
   470  	}
   471  	if len(sels) == 0 {
   472  		return bat, nil
   473  	}
   474  	for i, vec := range originBatch.Vecs {
   475  		err := bat.Vecs[i].UnionBatch(vec, 0, length, flags, proc.Mp())
   476  		if err != nil {
   477  			proc.PutBatch(bat)
   478  			return nil, err
   479  		}
   480  	}
   481  	bat.AddRowCount(originBatch.RowCount())
   482  	return bat, nil
   483  }