github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/mergegroup/group.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mergegroup
    16  
    17  import (
    18  	"bytes"
    19  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    20  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    21  	"github.com/matrixorigin/matrixone/pkg/container/types"
    22  	"github.com/matrixorigin/matrixone/pkg/vm"
    23  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    24  	"runtime"
    25  )
    26  
    27  const argName = "merge_group"
    28  
    29  func (arg *Argument) String(buf *bytes.Buffer) {
    30  	buf.WriteString(argName)
    31  	buf.WriteString(": mergeroup()")
    32  }
    33  
    34  func (arg *Argument) Prepare(proc *process.Process) error {
    35  	ap := arg
    36  	ap.ctr = new(container)
    37  	ap.ctr.InitReceiver(proc, true)
    38  	ap.ctr.inserted = make([]uint8, hashmap.UnitLimit)
    39  	ap.ctr.zInserted = make([]uint8, hashmap.UnitLimit)
    40  	return nil
    41  }
    42  
    43  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    44  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    45  		return vm.CancelResult, err
    46  	}
    47  
    48  	ap := arg
    49  	ctr := ap.ctr
    50  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    51  	anal.Start()
    52  	defer anal.Stop()
    53  	result := vm.NewCallResult()
    54  	for {
    55  		switch ctr.state {
    56  		case Build:
    57  			for {
    58  				bat, end, err := ctr.ReceiveFromAllRegs(anal)
    59  				if err != nil {
    60  					result.Status = vm.ExecStop
    61  					return result, nil
    62  				}
    63  
    64  				if end {
    65  					break
    66  				}
    67  				anal.Input(bat, arg.GetIsFirst())
    68  				if err = ctr.process(bat, proc); err != nil {
    69  					bat.Clean(proc.Mp())
    70  					return result, err
    71  				}
    72  			}
    73  			ctr.state = Eval
    74  
    75  		case Eval:
    76  			if ctr.bat != nil {
    77  				if ap.NeedEval {
    78  					for i, agg := range ctr.bat.Aggs {
    79  						if len(ap.PartialResults) > i && ap.PartialResults[i] != nil {
    80  							if err := agg.SetExtraInformation(ap.PartialResults[i], 0); err != nil {
    81  								return result, err
    82  							}
    83  						}
    84  						vec, err := agg.Flush()
    85  						if err != nil {
    86  							ctr.state = End
    87  							return result, err
    88  						}
    89  						ctr.bat.Aggs[i] = nil
    90  						ctr.bat.Vecs = append(ctr.bat.Vecs, vec)
    91  						if vec != nil {
    92  							anal.Alloc(int64(vec.Size()))
    93  						}
    94  
    95  						agg.Free()
    96  					}
    97  					ctr.bat.Aggs = nil
    98  				}
    99  				anal.Output(ctr.bat, arg.GetIsLast())
   100  				result.Batch = ctr.bat
   101  			}
   102  			ctr.state = End
   103  			return result, nil
   104  
   105  		case End:
   106  			result.Batch = nil
   107  			result.Status = vm.ExecStop
   108  			return result, nil
   109  		}
   110  	}
   111  }
   112  
   113  func (ctr *container) process(bat *batch.Batch, proc *process.Process) error {
   114  	var err error
   115  
   116  	if ctr.bat == nil {
   117  		keyWidth := 0
   118  		groupVecsNullable := false
   119  
   120  		for _, vec := range bat.Vecs {
   121  			groupVecsNullable = groupVecsNullable || (!vec.GetType().GetNotNull())
   122  		}
   123  
   124  		for _, vec := range bat.Vecs {
   125  			width := vec.GetType().TypeSize()
   126  			if vec.GetType().IsVarlen() {
   127  				if vec.GetType().Width == 0 {
   128  					switch vec.GetType().Oid {
   129  					case types.T_array_float32:
   130  						width = 128 * 4
   131  					case types.T_array_float64:
   132  						width = 128 * 8
   133  					default:
   134  						width = 128
   135  					}
   136  				} else {
   137  					switch vec.GetType().Oid {
   138  					case types.T_array_float32:
   139  						width = int(vec.GetType().Width) * 4
   140  					case types.T_array_float64:
   141  						width = int(vec.GetType().Width) * 8
   142  					default:
   143  						width = int(vec.GetType().Width)
   144  					}
   145  				}
   146  			}
   147  			keyWidth += width
   148  			if groupVecsNullable {
   149  				keyWidth += 1
   150  			}
   151  		}
   152  
   153  		switch {
   154  		case keyWidth == 0:
   155  			// no group by.
   156  			ctr.typ = H0
   157  
   158  		case keyWidth <= 8:
   159  			ctr.typ = H8
   160  			if ctr.intHashMap, err = hashmap.NewIntHashMap(groupVecsNullable, 0, 0, proc.Mp()); err != nil {
   161  				return err
   162  			}
   163  		default:
   164  			ctr.typ = HStr
   165  			if ctr.strHashMap, err = hashmap.NewStrMap(groupVecsNullable, 0, 0, proc.Mp()); err != nil {
   166  				return err
   167  			}
   168  		}
   169  	}
   170  
   171  	switch ctr.typ {
   172  	case H0:
   173  		err = ctr.processH0(bat, proc)
   174  	case H8:
   175  		err = ctr.processH8(bat, proc)
   176  	default:
   177  		err = ctr.processHStr(bat, proc)
   178  	}
   179  	return err
   180  }
   181  
   182  func (ctr *container) processH0(bat *batch.Batch, proc *process.Process) error {
   183  	if ctr.bat == nil {
   184  		ctr.bat = bat
   185  		return nil
   186  	}
   187  	defer proc.PutBatch(bat)
   188  	ctr.bat.SetRowCount(1)
   189  
   190  	for i, agg := range ctr.bat.Aggs {
   191  		err := agg.Merge(bat.Aggs[i], 0, 0)
   192  		if err != nil {
   193  			return err
   194  		}
   195  	}
   196  	return nil
   197  }
   198  
   199  func (ctr *container) processH8(bat *batch.Batch, proc *process.Process) error {
   200  	count := bat.RowCount()
   201  	itr := ctr.intHashMap.NewIterator()
   202  	flg := ctr.bat == nil
   203  	if !flg {
   204  		defer proc.PutBatch(bat)
   205  	}
   206  	for i := 0; i < count; i += hashmap.UnitLimit {
   207  		if i%(hashmap.UnitLimit*32) == 0 {
   208  			runtime.Gosched()
   209  		}
   210  		n := count - i
   211  		if n > hashmap.UnitLimit {
   212  			n = hashmap.UnitLimit
   213  		}
   214  		rowCount := ctr.intHashMap.GroupCount()
   215  		vals, _, err := itr.Insert(i, n, bat.Vecs)
   216  		if err != nil {
   217  			return err
   218  		}
   219  		if !flg {
   220  			if err = ctr.batchFill(i, n, bat, vals, rowCount, proc); err != nil {
   221  				return err
   222  			}
   223  		}
   224  	}
   225  	if flg {
   226  		ctr.bat = bat
   227  	}
   228  	return nil
   229  }
   230  
   231  func (ctr *container) processHStr(bat *batch.Batch, proc *process.Process) error {
   232  	count := bat.RowCount()
   233  	itr := ctr.strHashMap.NewIterator()
   234  	flg := ctr.bat == nil
   235  	if !flg {
   236  		defer proc.PutBatch(bat)
   237  	}
   238  	for i := 0; i < count; i += hashmap.UnitLimit { // batch
   239  		if i%(hashmap.UnitLimit*32) == 0 {
   240  			runtime.Gosched()
   241  		}
   242  		n := count - i
   243  		if n > hashmap.UnitLimit {
   244  			n = hashmap.UnitLimit
   245  		}
   246  		rowCount := ctr.strHashMap.GroupCount()
   247  		vals, _, err := itr.Insert(i, n, bat.Vecs)
   248  		if err != nil {
   249  			return err
   250  		}
   251  		if !flg {
   252  			if err := ctr.batchFill(i, n, bat, vals, rowCount, proc); err != nil {
   253  				return err
   254  			}
   255  		}
   256  	}
   257  	if flg {
   258  		ctr.bat = bat
   259  	}
   260  	return nil
   261  }
   262  
   263  func (ctr *container) batchFill(i int, n int, bat *batch.Batch, vals []uint64, hashRows uint64, proc *process.Process) error {
   264  	cnt := 0
   265  	copy(ctr.inserted[:n], ctr.zInserted[:n])
   266  	for k, v := range vals {
   267  		if v > hashRows {
   268  			ctr.inserted[k] = 1
   269  			hashRows++
   270  			cnt++
   271  		}
   272  	}
   273  	ctr.bat.AddRowCount(cnt)
   274  
   275  	if cnt > 0 {
   276  		for j, vec := range ctr.bat.Vecs {
   277  			if err := vec.UnionBatch(bat.Vecs[j], int64(i), cnt, ctr.inserted[:n], proc.Mp()); err != nil {
   278  				return err
   279  			}
   280  		}
   281  		for _, agg := range ctr.bat.Aggs {
   282  			if err := agg.GroupGrow(cnt); err != nil {
   283  				return err
   284  			}
   285  		}
   286  	}
   287  	for j, agg := range ctr.bat.Aggs {
   288  		if err := agg.BatchMerge(bat.Aggs[j], i, vals[:n]); err != nil {
   289  			return err
   290  		}
   291  	}
   292  	return nil
   293  }