github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/anti/join.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package anti
    16  
    17  import (
    18  	"bytes"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    21  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    22  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    23  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    24  	"github.com/matrixorigin/matrixone/pkg/vm"
    25  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    26  )
    27  
    28  const argName = "anti"
    29  
    30  func (arg *Argument) String(buf *bytes.Buffer) {
    31  	buf.WriteString(argName)
    32  	buf.WriteString(": anti join ")
    33  }
    34  
    35  func (arg *Argument) Prepare(proc *process.Process) (err error) {
    36  	ap := arg
    37  	ap.ctr = new(container)
    38  	ap.ctr.InitReceiver(proc, false)
    39  	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    40  
    41  	ap.ctr.vecs = make([]*vector.Vector, len(ap.Conditions[0]))
    42  	ap.ctr.executorForVecs, err = colexec.NewExpressionExecutorsFromPlanExpressions(proc, ap.Conditions[0])
    43  	if err != nil {
    44  		return err
    45  	}
    46  
    47  	if ap.Cond != nil {
    48  		ap.ctr.expr, err = colexec.NewExpressionExecutor(proc, ap.Cond)
    49  	}
    50  	return err
    51  }
    52  
    53  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    54  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    55  		return vm.CancelResult, err
    56  	}
    57  
    58  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    59  	anal.Start()
    60  	defer anal.Stop()
    61  	ap := arg
    62  	result := vm.NewCallResult()
    63  	ctr := ap.ctr
    64  	for {
    65  		switch ctr.state {
    66  		case Build:
    67  			if err := ctr.build(anal); err != nil {
    68  				return result, err
    69  			}
    70  			ctr.state = Probe
    71  
    72  		case Probe:
    73  			if ap.bat == nil {
    74  				bat, _, err := ctr.ReceiveFromSingleReg(0, anal)
    75  				if err != nil {
    76  					return result, err
    77  				}
    78  
    79  				if bat == nil {
    80  					ctr.state = End
    81  					continue
    82  				}
    83  				if bat.Last() {
    84  					result.Batch = bat
    85  					return result, nil
    86  				}
    87  				if bat.IsEmpty() {
    88  					continue
    89  				}
    90  
    91  				ap.bat = bat
    92  				ap.lastrow = 0
    93  			}
    94  
    95  			if ctr.mp == nil {
    96  				err := ctr.emptyProbe(ap, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result)
    97  				return result, err
    98  			} else {
    99  				err := ctr.probe(ap, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result)
   100  				if ap.lastrow == 0 {
   101  					proc.PutBatch(ap.bat)
   102  					ap.bat = nil
   103  				}
   104  				return result, err
   105  			}
   106  
   107  		default:
   108  			result.Batch = nil
   109  			result.Status = vm.ExecStop
   110  			return result, nil
   111  		}
   112  	}
   113  }
   114  
   115  func (ctr *container) receiveHashMap(anal process.Analyze) error {
   116  	bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   117  	if err != nil {
   118  		return err
   119  	}
   120  	if bat != nil && bat.AuxData != nil {
   121  		ctr.mp = bat.DupJmAuxData()
   122  		ctr.hasNull = ctr.mp.HasNull()
   123  		ctr.maxAllocSize = max(ctr.maxAllocSize, ctr.mp.Size())
   124  	}
   125  	return nil
   126  }
   127  
   128  func (ctr *container) receiveBatch(anal process.Analyze) error {
   129  	for {
   130  		bat, _, err := ctr.ReceiveFromSingleReg(1, anal)
   131  		if err != nil {
   132  			return err
   133  		}
   134  		if bat != nil {
   135  			ctr.batchRowCount += bat.RowCount()
   136  			ctr.batches = append(ctr.batches, bat)
   137  		} else {
   138  			break
   139  		}
   140  	}
   141  	for i := 0; i < len(ctr.batches)-1; i++ {
   142  		if ctr.batches[i].RowCount() != colexec.DefaultBatchSize {
   143  			panic("wrong batch received for hash build!")
   144  		}
   145  	}
   146  	return nil
   147  }
   148  
   149  func (ctr *container) build(anal process.Analyze) error {
   150  	err := ctr.receiveHashMap(anal)
   151  	if err != nil {
   152  		return err
   153  	}
   154  	return ctr.receiveBatch(anal)
   155  }
   156  
   157  func (ctr *container) emptyProbe(ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   158  	anal.Input(ap.bat, isFirst)
   159  	if ctr.rbat != nil {
   160  		proc.PutBatch(ctr.rbat)
   161  		ctr.rbat = nil
   162  	}
   163  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   164  	for i, pos := range ap.Result {
   165  		ctr.rbat.Vecs[i] = proc.GetVector(*ap.bat.Vecs[pos].GetType())
   166  		// for anti join, if left batch is sorted , then output batch is sorted
   167  		ctr.rbat.Vecs[i].SetSorted(ap.bat.Vecs[pos].GetSorted())
   168  	}
   169  	count := ap.bat.RowCount()
   170  	for i := ap.lastrow; i < count; i += hashmap.UnitLimit {
   171  		if ctr.rbat.RowCount() >= colexec.DefaultBatchSize {
   172  			anal.Output(ctr.rbat, isLast)
   173  			result.Batch = ctr.rbat
   174  			ap.lastrow = i
   175  			return nil
   176  		}
   177  		n := count - i
   178  		if n > hashmap.UnitLimit {
   179  			n = hashmap.UnitLimit
   180  		}
   181  		for k := 0; k < n; k++ {
   182  			for j, pos := range ap.Result {
   183  				if err := ctr.rbat.Vecs[j].UnionOne(ap.bat.Vecs[pos], int64(i+k), proc.Mp()); err != nil {
   184  					return err
   185  				}
   186  			}
   187  		}
   188  		ctr.rbat.AddRowCount(n)
   189  	}
   190  	anal.Output(ctr.rbat, isLast)
   191  	result.Batch = ctr.rbat
   192  	proc.PutBatch(ap.bat)
   193  	ap.lastrow = 0
   194  	ap.bat = nil
   195  	return nil
   196  }
   197  
   198  func (ctr *container) probe(ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error {
   199  
   200  	anal.Input(ap.bat, isFirst)
   201  	if ctr.rbat != nil {
   202  		proc.PutBatch(ctr.rbat)
   203  		ctr.rbat = nil
   204  	}
   205  	ctr.rbat = batch.NewWithSize(len(ap.Result))
   206  	for i, pos := range ap.Result {
   207  		ctr.rbat.Vecs[i] = proc.GetVector(*ap.bat.Vecs[pos].GetType())
   208  		// for anti join, if left batch is sorted , then output batch is sorted
   209  		ctr.rbat.Vecs[i].SetSorted(ap.bat.Vecs[pos].GetSorted())
   210  	}
   211  	if ctr.batchRowCount == 1 && ctr.hasNull {
   212  		result.Batch = ctr.rbat
   213  		anal.Output(ctr.rbat, isLast)
   214  		return nil
   215  	}
   216  
   217  	if err := ctr.evalJoinCondition(ap.bat, proc); err != nil {
   218  		return err
   219  	}
   220  
   221  	if ctr.joinBat1 == nil {
   222  		ctr.joinBat1, ctr.cfs1 = colexec.NewJoinBatch(ap.bat, proc.Mp())
   223  	}
   224  	if ctr.joinBat2 == nil && ctr.batchRowCount > 0 {
   225  		ctr.joinBat2, ctr.cfs2 = colexec.NewJoinBatch(ctr.batches[0], proc.Mp())
   226  	}
   227  
   228  	count := ap.bat.RowCount()
   229  	mSels := ctr.mp.Sels()
   230  	itr := ctr.mp.NewIterator()
   231  	eligible := make([]int32, 0, hashmap.UnitLimit)
   232  	for i := ap.lastrow; i < count; i += hashmap.UnitLimit {
   233  		if ctr.rbat.RowCount() >= colexec.DefaultBatchSize {
   234  			anal.Output(ctr.rbat, isLast)
   235  			result.Batch = ctr.rbat
   236  			ap.lastrow = i
   237  			return nil
   238  		}
   239  		n := count - i
   240  		if n > hashmap.UnitLimit {
   241  			n = hashmap.UnitLimit
   242  		}
   243  		copy(ctr.inBuckets, hashmap.OneUInt8s)
   244  		vals, zvals := itr.Find(i, n, ctr.vecs, ctr.inBuckets)
   245  
   246  		rowCountIncrease := 0
   247  		for k := 0; k < n; k++ {
   248  			if ctr.inBuckets[k] == 0 || zvals[k] == 0 {
   249  				continue
   250  			}
   251  			if vals[k] == 0 {
   252  				eligible = append(eligible, int32(i+k))
   253  				rowCountIncrease++
   254  				continue
   255  			}
   256  			if ap.Cond != nil {
   257  				if ap.HashOnPK {
   258  					idx1, idx2 := int64(vals[k]-1)/colexec.DefaultBatchSize, int64(vals[k]-1)%colexec.DefaultBatchSize
   259  					if err := colexec.SetJoinBatchValues(ctr.joinBat1, ap.bat, int64(i+k),
   260  						1, ctr.cfs1); err != nil {
   261  						return err
   262  					}
   263  					if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], idx2,
   264  						1, ctr.cfs2); err != nil {
   265  						return err
   266  					}
   267  					vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2})
   268  					if err != nil {
   269  						return err
   270  					}
   271  					if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   272  						continue
   273  					}
   274  					bs := vector.MustFixedCol[bool](vec)
   275  					if bs[0] {
   276  						continue
   277  					}
   278  				} else {
   279  					matched := false // mark if any tuple satisfies the condition
   280  					sels := mSels[vals[k]-1]
   281  					for _, sel := range sels {
   282  						idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
   283  						if err := colexec.SetJoinBatchValues(ctr.joinBat1, ap.bat, int64(i+k),
   284  							1, ctr.cfs1); err != nil {
   285  							return err
   286  						}
   287  						if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], int64(idx2),
   288  							1, ctr.cfs2); err != nil {
   289  							return err
   290  						}
   291  						vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2})
   292  						if err != nil {
   293  							return err
   294  						}
   295  						if vec.IsConstNull() || vec.GetNulls().Contains(0) {
   296  							continue
   297  						}
   298  						bs := vector.MustFixedCol[bool](vec)
   299  						if bs[0] {
   300  							matched = true
   301  							break
   302  						}
   303  					}
   304  					if matched {
   305  						continue
   306  					}
   307  				}
   308  				eligible = append(eligible, int32(i+k))
   309  				rowCountIncrease++
   310  			}
   311  		}
   312  		ctr.rbat.SetRowCount(ctr.rbat.RowCount() + rowCountIncrease)
   313  
   314  		for j, pos := range ap.Result {
   315  			if err := ctr.rbat.Vecs[j].Union(ap.bat.Vecs[pos], eligible, proc.Mp()); err != nil {
   316  				return err
   317  			}
   318  		}
   319  		eligible = eligible[:0]
   320  	}
   321  	anal.Output(ctr.rbat, isLast)
   322  	result.Batch = ctr.rbat
   323  	ap.lastrow = 0
   324  	return nil
   325  }
   326  
   327  func (ctr *container) evalJoinCondition(bat *batch.Batch, proc *process.Process) error {
   328  	for i := range ctr.executorForVecs {
   329  		vec, err := ctr.executorForVecs[i].Eval(proc, []*batch.Batch{bat})
   330  		if err != nil {
   331  			return err
   332  		}
   333  		ctr.vecs[i] = vec
   334  	}
   335  	return nil
   336  }