github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/intersectall/intersectall.go

// Copyright 2022 Matrix Origin
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package intersectall

import (
	"bytes"

	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
	"github.com/matrixorigin/matrixone/pkg/container/batch"
	"github.com/matrixorigin/matrixone/pkg/vm"
	"github.com/matrixorigin/matrixone/pkg/vm/process"
)

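// The operator runs as a small state machine: Build consumes the right
// relation to fill the hash table, Probe streams batches of the left relation
// against it, and End reports that all output has been produced.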
const (
	Build = iota
	Probe
	End
)

const argName = "intersect_all"

func (arg *Argument) String(buf *bytes.Buffer) {
	buf.WriteString(argName)
	buf.WriteString(": intersect all ")
}

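// Prepare initializes the operator's container: it creates the string hash map
// that will hold the right relation and allocates the UnitLimit-sized scratch
// slices used while building and probing.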
func (arg *Argument) Prepare(proc *process.Process) error {
	var err error
	ap := arg
	ap.ctr = new(container)
	ap.ctr.InitReceiver(proc, false)
	if ap.ctr.hashTable, err = hashmap.NewStrMap(true, ap.IBucket, ap.NBucket, proc.Mp()); err != nil {
		return err
	}
	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
	ap.ctr.inserted = make([]uint8, hashmap.UnitLimit)
	ap.ctr.resetInserted = make([]uint8, hashmap.UnitLimit)
	return nil
}

// Call is the execute method of the `intersect all` operator.
// It first builds a hash table from the right relation and uses an array to
// record how many times each key appears there.
// It then probes the hash table with values from the left relation and updates the array:
// values that do not exist in the hash table are thrown away, while values that do exist
// are preserved up to the minimum number of times they appear in either relation.
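//
// As an illustration (hypothetical values, not taken from the surrounding code):
// with a left relation {1, 1, 2, 3} and a right relation {1, 2, 2}, `intersect all`
// emits each key min(left count, right count) times, i.e. {1, 2}.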
func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
	if err, isCancel := vm.CancelCheck(proc); isCancel {
		return vm.CancelResult, err
	}

	var err error
	analyzer := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
	analyzer.Start()
	defer analyzer.Stop()
	result := vm.NewCallResult()
	for {
		switch arg.ctr.state {
		case Build:
			if err = arg.ctr.build(proc, analyzer, arg.GetIsFirst()); err != nil {
				return result, err
			}
			if arg.ctr.hashTable != nil {
				analyzer.Alloc(arg.ctr.hashTable.Size())
			}
			arg.ctr.state = Probe

		case Probe:
			last := false
			last, err = arg.ctr.probe(proc, analyzer, arg.GetIsFirst(), arg.GetIsLast(), &result)
			if err != nil {
				return result, err
			}
			if last {
				arg.ctr.state = End
				continue
			}
			return result, nil

		case End:
			result.Batch = nil
			result.Status = vm.ExecStop
			return result, nil
		}
	}
}

// build uses all batches from proc.Reg.MergeReceivers[1] (the right relation) to build the hash map.
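// As a sketch of the counter's contents (hypothetical values): if the right relation
// holds the keys {a, a, b}, the hash table ends up with two groups and the counter
// becomes [2, 1], i.e. counter[groupID-1] records how often that key occurred.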
func (ctr *container) build(proc *process.Process, analyzer process.Analyze, isFirst bool) error {
	for {
		bat, _, err := ctr.ReceiveFromSingleReg(1, analyzer)
		if err != nil {
			return err
		}

		if bat == nil {
			break
		}
		if bat.IsEmpty() {
			proc.PutBatch(bat)
			continue
		}

		analyzer.Input(bat, isFirst)
		// build hashTable and a counter to record how many times each key appears
		{
			itr := ctr.hashTable.NewIterator()
			count := bat.RowCount()
			for i := 0; i < count; i += hashmap.UnitLimit {
				n := count - i
				if n > hashmap.UnitLimit {
					n = hashmap.UnitLimit
				}
				vs, _, err := itr.Insert(i, n, bat.Vecs)
				if err != nil {
					bat.Clean(proc.Mp())
					return err
				}
				if uint64(cap(ctr.counter)) < ctr.hashTable.GroupCount() {
					gap := ctr.hashTable.GroupCount() - uint64(cap(ctr.counter))
					ctr.counter = append(ctr.counter, make([]uint64, gap)...)
				}
				for _, v := range vs {
					if v == 0 {
						continue
					}
					ctr.counter[v-1]++
				}
			}
			proc.PutBatch(bat)
		}
	}
	return nil
}

// probe uses batches from proc.Reg.MergeReceivers[0] (the left relation) to probe the hash map and update the counter.
// If a row of the batch does not appear in the hash table, it is skipped.
// If a row appears in the hash table and its entry in ctr.counter is greater than 0,
// the row is sent to the next operator and the counter is decremented; otherwise it is skipped.
// probe returns true once the last batch has been consumed, and false otherwise.
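// Continuing the sketch above (hypothetical values): probing a left relation
// {a, a, a, c} against the counter [2, 1] emits {a, a} and leaves the counter at [0, 1];
// the third `a` and the unmatched `c` are dropped.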
func (ctr *container) probe(proc *process.Process, analyzer process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) (bool, error) {
	if ctr.buf != nil {
		proc.PutBatch(ctr.buf)
		ctr.buf = nil
	}
	for {
		bat, _, err := ctr.ReceiveFromSingleReg(0, analyzer)
		if err != nil {
			return false, err
		}
		if bat == nil {
			return true, nil
		}
		analyzer.Input(bat, isFirst)
		if bat.Last() {
			ctr.buf = bat
			result.Batch = ctr.buf
			return false, nil
		}
		if bat.IsEmpty() {
			proc.PutBatch(bat)
			continue
		}
		// counter recording how many rows of the current chunk should be added to the output batch
		var cnt int

		// init the output batch
		ctr.buf = batch.NewWithSize(len(bat.Vecs))
		for i := range bat.Vecs {
			ctr.buf.Vecs[i] = proc.GetVector(*bat.Vecs[i].GetType())
		}

		// probe the hash table
		{
			itr := ctr.hashTable.NewIterator()
			count := bat.RowCount()
			for i := 0; i < count; i += hashmap.UnitLimit {
				n := count - i
				if n > hashmap.UnitLimit {
					n = hashmap.UnitLimit
				}

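				// reset the per-chunk scratch state: inBuckets is refilled with ones and the
				// inserted flags are cleared before this chunk of rows is probed.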
				copy(ctr.inBuckets, hashmap.OneUInt8s)
				copy(ctr.inserted[:n], ctr.resetInserted[:n])
				cnt = 0

				vs, _ := itr.Find(i, n, bat.Vecs, ctr.inBuckets)

				for j, v := range vs {
					// not in the processed bucket
					if ctr.inBuckets[j] == 0 {
						continue
					}

					// not found in the hash table
					if v == 0 {
						continue
					}

					// all matching rows for this key have already been added to the output batch
					if ctr.counter[v-1] == 0 {
						continue
					}

					ctr.inserted[j] = 1
					ctr.counter[v-1]--
					cnt++
				}
				ctr.buf.AddRowCount(cnt)

				if cnt > 0 {
					for colNum := range bat.Vecs {
						if err := ctr.buf.Vecs[colNum].UnionBatch(bat.Vecs[colNum], int64(i), cnt, ctr.inserted[:n], proc.Mp()); err != nil {
							bat.Clean(proc.Mp())
							return false, err
						}
					}
				}
			}
		}
		analyzer.Alloc(int64(ctr.buf.Size()))
		analyzer.Output(ctr.buf, isLast)

		result.Batch = ctr.buf
		proc.PutBatch(bat)
		return false, nil
	}
}