github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/group/group.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package group
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"runtime"
    21  
    22  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    23  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    24  	"github.com/matrixorigin/matrixone/pkg/container/types"
    25  	"github.com/matrixorigin/matrixone/pkg/sql/colexec/aggexec"
    26  	"github.com/matrixorigin/matrixone/pkg/sql/plan/function"
    27  	"github.com/matrixorigin/matrixone/pkg/vm"
    28  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    29  )
    30  
    31  const argName = "group"
    32  
    33  func (arg *Argument) String(buf *bytes.Buffer) {
    34  	buf.WriteString(argName)
    35  	ap := arg
    36  	buf.WriteString(": group([")
    37  	for i, expr := range ap.Exprs {
    38  		if i > 0 {
    39  			buf.WriteString(", ")
    40  		}
    41  		buf.WriteString(fmt.Sprintf("%v", expr))
    42  	}
    43  	buf.WriteString("], [")
    44  	for i, ag := range ap.Aggs {
    45  		if i > 0 {
    46  			buf.WriteString(", ")
    47  		}
    48  		buf.WriteString(fmt.Sprintf("%v(%v)", function.GetAggFunctionNameByID(ag.GetAggID()), ag.GetArgExpressions()))
    49  	}
    50  	buf.WriteString("])")
    51  }
    52  
    53  func (arg *Argument) Prepare(proc *process.Process) (err error) {
    54  	ap := arg
    55  	ap.ctr = new(container)
    56  	ap.ctr.inserted = make([]uint8, hashmap.UnitLimit)
    57  	ap.ctr.zInserted = make([]uint8, hashmap.UnitLimit)
    58  
    59  	ctr := ap.ctr
    60  	ctr.state = vm.Build
    61  
    62  	// create executors for aggregation functions.
    63  	if len(ap.Aggs) > 0 {
    64  		ctr.aggVecs = make([]ExprEvalVector, len(ap.Aggs))
    65  		for i, ag := range ap.Aggs {
    66  			expressions := ag.GetArgExpressions()
    67  			if ctr.aggVecs[i], err = MakeEvalVector(proc, expressions); err != nil {
    68  				return err
    69  			}
    70  		}
    71  	}
    72  
    73  	// create executors for group by columns.
    74  	ctr.keyWidth = 0
    75  	if ap.Exprs != nil {
    76  		ctr.groupVecsNullable = false
    77  		ctr.groupVecs, err = MakeEvalVector(proc, ap.Exprs)
    78  		if err != nil {
    79  			return err
    80  		}
    81  		for _, gv := range ap.Exprs {
    82  			ctr.groupVecsNullable = ctr.groupVecsNullable || (!gv.Typ.NotNullable)
    83  		}
    84  
    85  		for _, expr := range ap.Exprs {
    86  			typ := expr.Typ
    87  			width := types.T(typ.Id).TypeLen()
    88  			if types.T(typ.Id).FixedLength() < 0 {
    89  				if typ.Width == 0 {
    90  					switch types.T(typ.Id) {
    91  					case types.T_array_float32:
    92  						width = 128 * 4
    93  					case types.T_array_float64:
    94  						width = 128 * 8
    95  					default:
    96  						width = 128
    97  					}
    98  				} else {
    99  					switch types.T(typ.Id) {
   100  					case types.T_array_float32:
   101  						width = int(typ.Width) * 4
   102  					case types.T_array_float64:
   103  						width = int(typ.Width) * 8
   104  					default:
   105  						width = int(typ.Width)
   106  					}
   107  				}
   108  			}
   109  			ctr.keyWidth += width
   110  			if ctr.groupVecsNullable {
   111  				ctr.keyWidth += 1
   112  			}
   113  		}
   114  	}
   115  
   116  	return nil
   117  }
   118  
   119  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
   120  	if err, isCancel := vm.CancelCheck(proc); isCancel {
   121  		return vm.CancelResult, err
   122  	}
   123  
   124  	ap := arg
   125  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
   126  	anal.Start()
   127  	defer anal.Stop()
   128  
   129  	// if operator has no group by clause.
   130  	if len(ap.Exprs) == 0 {
   131  		// if operator has no group by clause.
   132  		return ap.ctr.processWithoutGroup(ap, proc, anal, arg.GetIsFirst(), arg.GetIsLast())
   133  	}
   134  	return ap.ctr.processWithGroup(ap, proc, anal, arg.GetIsFirst(), arg.GetIsLast())
   135  }
   136  
   137  func (ctr *container) generateAggStructures(proc *process.Process, arg *Argument) error {
   138  	for i, ag := range arg.Aggs {
   139  		ctr.bat.Aggs[i] = aggexec.MakeAgg(
   140  			proc,
   141  			ag.GetAggID(), ag.IsDistinct(), ctr.aggVecs[i].Typ...)
   142  
   143  		if config := ag.GetExtraConfig(); config != nil {
   144  			if err := ctr.bat.Aggs[i].SetExtraInformation(config, 0); err != nil {
   145  				return err
   146  			}
   147  		}
   148  	}
   149  
   150  	if preAllocate := int(arg.PreAllocSize); preAllocate > 0 {
   151  		for _, ag := range ctr.bat.Aggs {
   152  			if err := ag.PreAllocateGroups(preAllocate); err != nil {
   153  				return err
   154  			}
   155  		}
   156  	}
   157  	return nil
   158  }
   159  
   160  func (ctr *container) processWithoutGroup(ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool) (vm.CallResult, error) {
   161  	if ctr.state == vm.Build {
   162  		for {
   163  			result, err := vm.ChildrenCall(ap.GetChildren(0), proc, anal)
   164  			if err != nil {
   165  				return result, err
   166  			}
   167  			if result.Batch == nil {
   168  				ctr.state = vm.Eval
   169  				break
   170  			}
   171  			if result.Batch.IsEmpty() {
   172  				continue
   173  			}
   174  			bat := result.Batch
   175  			anal.Input(bat, isFirst)
   176  
   177  			if err = ctr.evalAggVector(bat, proc); err != nil {
   178  				return result, err
   179  			}
   180  
   181  			if ctr.bat == nil {
   182  				if err = initCtrBatchForProcessWithoutGroup(ap, proc, ctr); err != nil {
   183  					return result, err
   184  				}
   185  			}
   186  
   187  			if err = ctr.processH0(); err != nil {
   188  				return result, err
   189  			}
   190  		}
   191  	}
   192  
   193  	result := vm.NewCallResult()
   194  	if ctr.state == vm.Eval {
   195  
   196  		// the result of Agg can't be empty but 0 or NULL.
   197  		if !ctr.hasAggResult {
   198  			// very bad code.
   199  			if err := initCtrBatchForProcessWithoutGroup(ap, proc, ctr); err != nil {
   200  				return result, err
   201  			}
   202  		}
   203  		if ctr.bat != nil {
   204  			anal.Alloc(int64(ctr.bat.Size()))
   205  			anal.Output(ctr.bat, isLast)
   206  		}
   207  
   208  		result.Batch = ctr.bat
   209  		ctr.state = vm.End
   210  		return result, nil
   211  	}
   212  
   213  	if ctr.state == vm.End {
   214  		return result, nil
   215  	}
   216  
   217  	panic("bug")
   218  }
   219  
   220  func initCtrBatchForProcessWithoutGroup(ap *Argument, proc *process.Process, ctr *container) (err error) {
   221  	ctr.bat = batch.NewWithSize(0)
   222  	ctr.bat.SetRowCount(1)
   223  
   224  	ctr.bat.Aggs = make([]aggexec.AggFuncExec, len(ap.Aggs))
   225  	if err = ctr.generateAggStructures(proc, ap); err != nil {
   226  		return err
   227  	}
   228  	for _, ag := range ctr.bat.Aggs {
   229  		if err = ag.GroupGrow(1); err != nil {
   230  			return err
   231  		}
   232  	}
   233  	return err
   234  }
   235  
   236  func (ctr *container) processWithGroup(ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool) (vm.CallResult, error) {
   237  	if ctr.state == vm.Build {
   238  		for {
   239  			result, err := vm.ChildrenCall(ap.GetChildren(0), proc, anal)
   240  			if err != nil {
   241  				return result, err
   242  			}
   243  			if result.Batch == nil {
   244  				ctr.state = vm.Eval
   245  				break
   246  			}
   247  			if result.Batch.IsEmpty() {
   248  				continue
   249  			}
   250  			bat := result.Batch
   251  			// defer bat.Clean(proc.Mp())
   252  			anal.Input(bat, isFirst)
   253  
   254  			if err = ctr.evalAggVector(bat, proc); err != nil {
   255  				return result, err
   256  			}
   257  
   258  			for i := range ap.Exprs {
   259  				ctr.groupVecs.Vec[i], err = ctr.groupVecs.Executor[i].Eval(proc, []*batch.Batch{bat})
   260  				if err != nil {
   261  					return result, err
   262  				}
   263  			}
   264  
   265  			if ctr.bat == nil {
   266  				ctr.bat = batch.NewWithSize(len(ap.Exprs))
   267  				for i, vec := range ctr.groupVecs.Vec {
   268  					ctr.bat.Vecs[i] = proc.GetVector(*vec.GetType())
   269  				}
   270  				if ap.PreAllocSize > 0 {
   271  					err = ctr.bat.PreExtend(proc.Mp(), int(ap.PreAllocSize))
   272  					if err != nil {
   273  						return result, err
   274  					}
   275  				}
   276  				ctr.bat.Aggs = make([]aggexec.AggFuncExec, len(ap.Aggs))
   277  				if err = ctr.generateAggStructures(proc, ap); err != nil {
   278  					return result, err
   279  				}
   280  				switch {
   281  				//case ctr.idx != nil:
   282  				//	ctr.typ = HIndex
   283  				case ctr.keyWidth <= 8:
   284  					ctr.typ = H8
   285  					if ctr.intHashMap, err = hashmap.NewIntHashMap(ctr.groupVecsNullable, ap.Ibucket, ap.Nbucket, proc.Mp()); err != nil {
   286  						return result, err
   287  					}
   288  					if ap.PreAllocSize > 0 {
   289  						err = ctr.intHashMap.PreAlloc(ap.PreAllocSize, proc.Mp())
   290  						if err != nil {
   291  							return result, err
   292  						}
   293  					}
   294  				default:
   295  					ctr.typ = HStr
   296  					if ctr.strHashMap, err = hashmap.NewStrMap(ctr.groupVecsNullable, ap.Ibucket, ap.Nbucket, proc.Mp()); err != nil {
   297  						return result, err
   298  					}
   299  					if ap.PreAllocSize > 0 {
   300  						err = ctr.strHashMap.PreAlloc(ap.PreAllocSize, proc.Mp())
   301  						if err != nil {
   302  							return result, err
   303  						}
   304  					}
   305  				}
   306  			}
   307  
   308  			switch ctr.typ {
   309  			case H8:
   310  				err = ctr.processH8(bat, proc)
   311  			case HStr:
   312  				err = ctr.processHStr(bat, proc)
   313  			default:
   314  			}
   315  			if err != nil {
   316  				return result, err
   317  			}
   318  		}
   319  	}
   320  
   321  	result := vm.NewCallResult()
   322  	if ctr.state == vm.Eval {
   323  		if ctr.bat != nil {
   324  			if ap.NeedEval {
   325  				for i, ag := range ctr.bat.Aggs {
   326  					vec, err := ag.Flush()
   327  					if err != nil {
   328  						return result, err
   329  					}
   330  					ctr.bat.Aggs[i] = nil
   331  					ctr.bat.Vecs = append(ctr.bat.Vecs, vec)
   332  					anal.Alloc(int64(vec.Size()))
   333  
   334  					ag.Free()
   335  				}
   336  				ctr.bat.Aggs = nil
   337  			}
   338  			anal.Output(ctr.bat, isLast)
   339  		}
   340  
   341  		result.Batch = ctr.bat
   342  		ctr.state = vm.End
   343  		return result, nil
   344  	}
   345  
   346  	if ctr.state == vm.End {
   347  		return result, nil
   348  	}
   349  
   350  	panic("bug")
   351  }
   352  
   353  // processH8 use whole batch to fill the aggregation.
   354  func (ctr *container) processH0() error {
   355  	ctr.bat.SetRowCount(1)
   356  
   357  	for i, ag := range ctr.bat.Aggs {
   358  		err := ag.BulkFill(0, ctr.aggVecs[i].Vec)
   359  		if err != nil {
   360  			return err
   361  		}
   362  	}
   363  
   364  	return nil
   365  }
   366  
   367  // processH8 do group by aggregation with int hashmap.
   368  func (ctr *container) processH8(bat *batch.Batch, proc *process.Process) error {
   369  	count := bat.RowCount()
   370  	itr := ctr.intHashMap.NewIterator()
   371  	for i := 0; i < count; i += hashmap.UnitLimit {
   372  		if i%(hashmap.UnitLimit*32) == 0 {
   373  			runtime.Gosched()
   374  		}
   375  		n := count - i
   376  		if n > hashmap.UnitLimit {
   377  			n = hashmap.UnitLimit
   378  		}
   379  		rows := ctr.intHashMap.GroupCount()
   380  		vals, _, err := itr.Insert(i, n, ctr.groupVecs.Vec)
   381  		if err != nil {
   382  			return err
   383  		}
   384  		if err = ctr.batchFill(i, n, vals, rows, proc); err != nil {
   385  			return err
   386  		}
   387  	}
   388  	return nil
   389  }
   390  
   391  // processHStr do group by aggregation with string hashmap.
   392  func (ctr *container) processHStr(bat *batch.Batch, proc *process.Process) error {
   393  	count := bat.RowCount()
   394  	itr := ctr.strHashMap.NewIterator()
   395  	for i := 0; i < count; i += hashmap.UnitLimit { // batch
   396  		if i%(hashmap.UnitLimit*32) == 0 {
   397  			runtime.Gosched()
   398  		}
   399  		n := count - i
   400  		if n > hashmap.UnitLimit {
   401  			n = hashmap.UnitLimit
   402  		}
   403  		rows := ctr.strHashMap.GroupCount()
   404  		vals, _, err := itr.Insert(i, n, ctr.groupVecs.Vec)
   405  		if err != nil {
   406  			return err
   407  		}
   408  		if err = ctr.batchFill(i, n, vals, rows, proc); err != nil {
   409  			return err
   410  		}
   411  	}
   412  	return nil
   413  }
   414  
   415  func (ctr *container) batchFill(i int, n int, vals []uint64, hashRows uint64, proc *process.Process) error {
   416  	cnt := 0
   417  	valCnt := 0
   418  	copy(ctr.inserted[:n], ctr.zInserted[:n])
   419  	for k, v := range vals[:n] {
   420  		if v == 0 {
   421  			continue
   422  		}
   423  		if v > hashRows {
   424  			ctr.inserted[k] = 1
   425  			hashRows++
   426  			cnt++
   427  		}
   428  		valCnt++
   429  	}
   430  	ctr.bat.AddRowCount(cnt)
   431  
   432  	if cnt > 0 {
   433  		for j, vec := range ctr.bat.Vecs {
   434  			if err := vec.UnionBatch(ctr.groupVecs.Vec[j], int64(i), cnt, ctr.inserted[:n], proc.Mp()); err != nil {
   435  				return err
   436  			}
   437  		}
   438  		for _, ag := range ctr.bat.Aggs {
   439  			if err := ag.GroupGrow(cnt); err != nil {
   440  				return err
   441  			}
   442  		}
   443  	}
   444  	if valCnt == 0 {
   445  		return nil
   446  	}
   447  	for j, ag := range ctr.bat.Aggs {
   448  		err := ag.BatchFill(i, vals[:n], ctr.aggVecs[j].Vec)
   449  		if err != nil {
   450  			return err
   451  		}
   452  	}
   453  	return nil
   454  }
   455  
   456  func (ctr *container) evalAggVector(bat *batch.Batch, proc *process.Process) (err error) {
   457  	ctr.hasAggResult = true
   458  	input := []*batch.Batch{bat}
   459  
   460  	for i := range ctr.aggVecs {
   461  		for j := range ctr.aggVecs[i].Executor {
   462  			ctr.aggVecs[i].Vec[j], err = ctr.aggVecs[i].Executor[j].Eval(proc, input)
   463  			if err != nil {
   464  				return err
   465  			}
   466  		}
   467  	}
   468  	return nil
   469  }