github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/colexec/intersect/intersect.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package intersect
    16  
    17  import (
    18  	"bytes"
    19  	"time"
    20  
    21  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    22  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    23  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    24  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    25  )
    26  
    27  func String(_ any, buf *bytes.Buffer) {
    28  	buf.WriteString(" intersect ")
    29  }
    30  
    31  func Prepare(proc *process.Process, argument any) error {
    32  	var err error
    33  	arg := argument.(*Argument)
    34  	arg.ctr.btc = nil
    35  	arg.ctr.hashTable, err = hashmap.NewStrMap(true, arg.IBucket, arg.NBucket, proc.Mp())
    36  	if err != nil {
    37  		return err
    38  	}
    39  	arg.ctr.inBuckets = make([]uint8, hashmap.UnitLimit)
    40  	return nil
    41  }
    42  
    43  func Call(idx int, proc *process.Process, argument any, isFirst bool, isLast bool) (bool, error) {
    44  	arg := argument.(*Argument)
    45  
    46  	analyze := proc.GetAnalyze(idx)
    47  	analyze.Start()
    48  	defer analyze.Stop()
    49  
    50  	for {
    51  		switch arg.ctr.state {
    52  		case build:
    53  			if err := arg.ctr.buildHashTable(proc, analyze, 1, isFirst); err != nil {
    54  				arg.Free(proc, true)
    55  				return false, err
    56  			}
    57  			if arg.ctr.hashTable != nil {
    58  				analyze.Alloc(arg.ctr.hashTable.Size())
    59  			}
    60  			arg.ctr.state = probe
    61  
    62  		case probe:
    63  			var err error
    64  			isLast := false
    65  			if isLast, err = arg.ctr.probeHashTable(proc, analyze, 0, isFirst, isLast); err != nil {
    66  				return true, err
    67  			}
    68  			if isLast {
    69  				arg.ctr.state = end
    70  				continue
    71  			}
    72  
    73  			return false, nil
    74  
    75  		case end:
    76  			arg.Free(proc, false)
    77  			proc.SetInputBatch(nil)
    78  			return true, nil
    79  		}
    80  	}
    81  }
    82  
    83  // build hash table
    84  func (c *container) buildHashTable(proc *process.Process, analyse process.Analyze, idx int, isFirst bool) error {
    85  	for {
    86  		start := time.Now()
    87  		btc := <-proc.Reg.MergeReceivers[idx].Ch
    88  		analyse.WaitStop(start)
    89  
    90  		// last batch of block
    91  		if btc == nil {
    92  			break
    93  		}
    94  
    95  		// empty batch
    96  		if btc.Length() == 0 {
    97  			continue
    98  		}
    99  
   100  		analyse.Input(btc, isFirst)
   101  
   102  		cnt := btc.Length()
   103  		itr := c.hashTable.NewIterator()
   104  		for i := 0; i < cnt; i += hashmap.UnitLimit {
   105  			rowcnt := c.hashTable.GroupCount()
   106  
   107  			n := cnt - i
   108  			if n > hashmap.UnitLimit {
   109  				n = hashmap.UnitLimit
   110  			}
   111  
   112  			vs, zs, err := itr.Insert(i, n, btc.Vecs)
   113  			if err != nil {
   114  				btc.Clean(proc.Mp())
   115  				return err
   116  			}
   117  
   118  			for j, v := range vs {
   119  				if zs[j] == 0 {
   120  					continue
   121  				}
   122  
   123  				if v > rowcnt {
   124  					c.cnts = append(c.cnts, proc.Mp().GetSels())
   125  					c.cnts[v-1] = append(c.cnts[v-1], 1)
   126  					rowcnt++
   127  				}
   128  			}
   129  		}
   130  		btc.Clean(proc.Mp())
   131  	}
   132  	return nil
   133  }
   134  
   135  func (c *container) probeHashTable(proc *process.Process, analyze process.Analyze, idx int, isFirst bool, isLast bool) (bool, error) {
   136  	for {
   137  		start := time.Now()
   138  		btc := <-proc.Reg.MergeReceivers[idx].Ch
   139  		analyze.WaitStop(start)
   140  
   141  		// last batch of block
   142  		if btc == nil {
   143  			return true, nil
   144  		}
   145  
   146  		// empty batch
   147  		if btc.Length() == 0 {
   148  			continue
   149  		}
   150  
   151  		analyze.Input(btc, isFirst)
   152  
   153  		c.btc = batch.NewWithSize(len(btc.Vecs))
   154  		for i := range btc.Vecs {
   155  			c.btc.Vecs[i] = vector.New(btc.Vecs[i].Typ)
   156  		}
   157  		needInsert := make([]uint8, hashmap.UnitLimit)
   158  		resetsNeedInsert := make([]uint8, hashmap.UnitLimit)
   159  		cnt := btc.Length()
   160  		itr := c.hashTable.NewIterator()
   161  		for i := 0; i < cnt; i += hashmap.UnitLimit {
   162  			n := cnt - i
   163  			if n > hashmap.UnitLimit {
   164  				n = hashmap.UnitLimit
   165  			}
   166  
   167  			copy(c.inBuckets, hashmap.OneUInt8s)
   168  			copy(needInsert, resetsNeedInsert)
   169  			insertcnt := 0
   170  
   171  			vs, zs := itr.Find(i, n, btc.Vecs, c.inBuckets)
   172  
   173  			for j, v := range vs {
   174  				// not in the processed bucket
   175  				if c.inBuckets[j] == 0 {
   176  					continue
   177  				}
   178  
   179  				// null value
   180  				if zs[j] == 0 {
   181  					continue
   182  				}
   183  
   184  				// not found
   185  				if v == 0 {
   186  					continue
   187  				}
   188  
   189  				// has been added into output batch
   190  				if c.cnts[v-1][0] == 0 {
   191  					continue
   192  				}
   193  
   194  				needInsert[j] = 1
   195  				c.cnts[v-1][0] = 0
   196  				c.btc.Zs = append(c.btc.Zs, 1)
   197  				insertcnt++
   198  			}
   199  
   200  			if insertcnt > 0 {
   201  				for pos := range btc.Vecs {
   202  					if err := vector.UnionBatch(c.btc.Vecs[pos], btc.Vecs[pos], int64(i), insertcnt, needInsert, proc.Mp()); err != nil {
   203  						btc.Clean(proc.Mp())
   204  						return false, err
   205  					}
   206  				}
   207  			}
   208  		}
   209  
   210  		btc.Clean(proc.Mp())
   211  		analyze.Alloc(int64(c.btc.Size()))
   212  		analyze.Output(c.btc, isLast)
   213  		proc.SetInputBatch(c.btc)
   214  		return false, nil
   215  	}
   216  }