github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/colexec/intersectall/intersectall.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package intersectall
    16  
    17  import (
    18  	"bytes"
    19  	"time"
    20  
    21  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    22  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    23  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    24  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    25  )
    26  
    27  const (
    28  	Build = iota
    29  	Probe
    30  	End
    31  )
    32  
    33  func String(_ any, buf *bytes.Buffer) {
    34  	buf.WriteString(" intersect all ")
    35  }
    36  
    37  func Prepare(proc *process.Process, arg any) error {
    38  	var err error
    39  	ap := arg.(*Argument)
    40  	ap.ctr = new(container)
    41  	if ap.ctr.hashTable, err = hashmap.NewStrMap(true, ap.IBucket, ap.NBucket, proc.Mp()); err != nil {
    42  		return err
    43  	}
    44  	ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    45  	ap.ctr.inserted = make([]uint8, hashmap.UnitLimit)
    46  	ap.ctr.resetInserted = make([]uint8, hashmap.UnitLimit)
    47  	return nil
    48  }
    49  
    50  // Call is the execute method of `intersect all` operator
    51  // it built a hash table for right relation first.
    52  // and use an array to record how many times each key appears in right relation.
    53  // use values from left relation to probe and update the array.
    54  // throw away values that do not exist in the hash table.
    55  // preserve values that exist in the hash table (the minimum of the number of times that exist in either).
    56  func Call(idx int, proc *process.Process, argument any, isFirst bool, isLast bool) (bool, error) {
    57  	var err error
    58  	analyzer := proc.GetAnalyze(idx)
    59  	analyzer.Start()
    60  	defer analyzer.Stop()
    61  	arg := argument.(*Argument)
    62  	for {
    63  		switch arg.ctr.state {
    64  		case Build:
    65  			if err = arg.ctr.build(proc, analyzer, isFirst); err != nil {
    66  				arg.Free(proc, true)
    67  				return false, err
    68  			}
    69  			if arg.ctr.hashTable != nil {
    70  				analyzer.Alloc(arg.ctr.hashTable.Size())
    71  			}
    72  			arg.ctr.state = Probe
    73  
    74  		case Probe:
    75  			last := false
    76  			last, err = arg.ctr.probe(proc, analyzer, isFirst, isLast)
    77  			if err != nil {
    78  				arg.Free(proc, true)
    79  				return false, err
    80  			}
    81  			if last {
    82  				arg.ctr.state = End
    83  				continue
    84  			}
    85  			return false, nil
    86  
    87  		case End:
    88  			arg.Free(proc, false)
    89  			proc.SetInputBatch(nil)
    90  			return true, nil
    91  		}
    92  	}
    93  }
    94  
    95  // build use all batches from proc.Reg.MergeReceiver[1](right relation) to build the hash map.
    96  func (ctr *container) build(proc *process.Process, analyzer process.Analyze, isFirst bool) error {
    97  	for {
    98  		start := time.Now()
    99  		bat := <-proc.Reg.MergeReceivers[1].Ch
   100  		analyzer.WaitStop(start)
   101  
   102  		if bat == nil {
   103  			break
   104  		}
   105  		if len(bat.Zs) == 0 {
   106  			continue
   107  		}
   108  
   109  		analyzer.Input(bat, isFirst)
   110  		// build hashTable and a counter to record how many times each key appears
   111  		{
   112  			itr := ctr.hashTable.NewIterator()
   113  			count := bat.Length()
   114  			for i := 0; i < count; i += hashmap.UnitLimit {
   115  
   116  				n := count - i
   117  				if n > hashmap.UnitLimit {
   118  					n = hashmap.UnitLimit
   119  				}
   120  				vs, _, err := itr.Insert(i, n, bat.Vecs)
   121  				if err != nil {
   122  					bat.Clean(proc.Mp())
   123  					return err
   124  				}
   125  				if uint64(cap(ctr.counter)) < ctr.hashTable.GroupCount() {
   126  					gap := ctr.hashTable.GroupCount() - uint64(cap(ctr.counter))
   127  					ctr.counter = append(ctr.counter, make([]uint64, gap)...)
   128  				}
   129  				for _, v := range vs {
   130  					if v == 0 {
   131  						continue
   132  					}
   133  					ctr.counter[v-1]++
   134  				}
   135  			}
   136  			bat.Clean(proc.Mp())
   137  		}
   138  
   139  	}
   140  	return nil
   141  }
   142  
   143  // probe uses a batch from proc.Reg.MergeReceivers[0](left relation) to probe the hash map and update the counter.
   144  // If a row of the batch doesn't appear in the hash table, continue.
   145  // If a row of the batch appears in the hash table and the value of it in the ctr.counter is greater than 0,
   146  // send it to the next operator and counter--; else, continue.
   147  // if batch is the last one, return true, else return false.
   148  func (ctr *container) probe(proc *process.Process, analyzer process.Analyze, isFirst bool, isLast bool) (bool, error) {
   149  	for {
   150  		start := time.Now()
   151  		bat := <-proc.Reg.MergeReceivers[0].Ch
   152  		analyzer.WaitStop(start)
   153  
   154  		if bat == nil {
   155  			return true, nil
   156  		}
   157  		if len(bat.Zs) == 0 {
   158  			continue
   159  		}
   160  
   161  		analyzer.Input(bat, isFirst)
   162  		//data to send to the next op
   163  		var outputBat *batch.Batch
   164  		//counter to record whether a row should add to output batch or not
   165  		var cnt int
   166  
   167  		//init output batch
   168  		{
   169  			outputBat = batch.NewWithSize(len(bat.Vecs))
   170  			for i := range bat.Vecs {
   171  				outputBat.Vecs[i] = vector.New(bat.Vecs[i].Typ)
   172  			}
   173  		}
   174  
   175  		// probe hashTable
   176  		{
   177  			itr := ctr.hashTable.NewIterator()
   178  			count := bat.Length()
   179  			for i := 0; i < count; i += hashmap.UnitLimit {
   180  				n := count - i
   181  				if n > hashmap.UnitLimit {
   182  					n = hashmap.UnitLimit
   183  				}
   184  
   185  				copy(ctr.inBuckets, hashmap.OneUInt8s)
   186  				copy(ctr.inserted[:n], ctr.resetInserted[:n])
   187  				cnt = 0
   188  
   189  				vs, _ := itr.Find(i, n, bat.Vecs, ctr.inBuckets)
   190  
   191  				for j, v := range vs {
   192  					// not in the processed bucket
   193  					if ctr.inBuckets[j] == 0 {
   194  						continue
   195  					}
   196  
   197  					// not found
   198  					if v == 0 {
   199  						continue
   200  					}
   201  
   202  					//  all common row has been added into output batch
   203  					if ctr.counter[v-1] == 0 {
   204  						continue
   205  					}
   206  
   207  					ctr.inserted[j] = 1
   208  					ctr.counter[v-1]--
   209  					outputBat.Zs = append(outputBat.Zs, 1)
   210  					cnt++
   211  
   212  				}
   213  				if cnt > 0 {
   214  					for colNum := range bat.Vecs {
   215  						if err := vector.UnionBatch(outputBat.Vecs[colNum], bat.Vecs[colNum], int64(i), cnt, ctr.inserted[:n], proc.Mp()); err != nil {
   216  							bat.Clean(proc.Mp())
   217  							return false, err
   218  						}
   219  					}
   220  				}
   221  			}
   222  
   223  		}
   224  		analyzer.Alloc(int64(outputBat.Size()))
   225  		analyzer.Output(outputBat, isLast)
   226  		proc.SetInputBatch(outputBat)
   227  		bat.Clean(proc.Mp())
   228  		return false, nil
   229  	}
   230  }