github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/colexec/group/group.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package group
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  
    21  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    22  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    23  	"github.com/matrixorigin/matrixone/pkg/container/index"
    24  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    25  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    26  	"github.com/matrixorigin/matrixone/pkg/sql/colexec/agg"
    27  	"github.com/matrixorigin/matrixone/pkg/sql/colexec/multi_col/group_concat"
    28  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    29  )
    30  
    31  func String(arg any, buf *bytes.Buffer) {
    32  	ap := arg.(*Argument)
    33  	buf.WriteString("group([")
    34  	for i, expr := range ap.Exprs {
    35  		if i > 0 {
    36  			buf.WriteString(", ")
    37  		}
    38  		buf.WriteString(fmt.Sprintf("%v", expr))
    39  	}
    40  	buf.WriteString("], [")
    41  	for i, ag := range ap.Aggs {
    42  		if i > 0 {
    43  			buf.WriteString(", ")
    44  		}
    45  		buf.WriteString(fmt.Sprintf("%v(%v)", agg.Names[ag.Op], ag.E))
    46  	}
    47  	if len(ap.MultiAggs) != 0 {
    48  		if len(ap.Aggs) > 0 {
    49  			buf.WriteString(",")
    50  		}
    51  		for i, ag := range ap.MultiAggs {
    52  			if i > 0 {
    53  				buf.WriteString(",")
    54  			}
    55  			buf.WriteString("group_concat(")
    56  			for _, expr := range ag.GroupExpr {
    57  				buf.WriteString(fmt.Sprintf("%v ", expr))
    58  			}
    59  			buf.WriteString(")")
    60  		}
    61  	}
    62  	buf.WriteString("])")
    63  }
    64  
    65  func Prepare(_ *process.Process, arg any) error {
    66  	ap := arg.(*Argument)
    67  	ap.ctr = new(container)
    68  	ap.ctr.inserted = make([]uint8, hashmap.UnitLimit)
    69  	ap.ctr.zInserted = make([]uint8, hashmap.UnitLimit)
    70  	return nil
    71  }
    72  
    73  func Call(idx int, proc *process.Process, arg any, isFirst bool, isLast bool) (bool, error) {
    74  	var end bool
    75  	var err error
    76  	ap := arg.(*Argument)
    77  	anal := proc.GetAnalyze(idx)
    78  	anal.Start()
    79  	defer anal.Stop()
    80  
    81  	if len(ap.Exprs) == 0 {
    82  		end, err = ap.ctr.process(ap, proc, anal, isFirst, isLast)
    83  	} else {
    84  		end, err = ap.ctr.processWithGroup(ap, proc, anal, isFirst, isLast)
    85  	}
    86  	if err != nil {
    87  		ap.Free(proc, true)
    88  	}
    89  	if end {
    90  		ap.Free(proc, false)
    91  	}
    92  	return end, err
    93  }
    94  
    95  func (ctr *container) process(ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool) (bool, error) {
    96  	bat := proc.InputBatch()
    97  	if bat == nil {
    98  		// if the result vectors are empty, process again. because the result of Agg can't be empty but 0 or NULL.
    99  		if len(ctr.aggVecs) == 0 && len(ctr.multiVecs) == 0 {
   100  			b := batch.NewWithSize(len(ap.Types))
   101  			for i := range b.Vecs {
   102  				b.Vecs[i] = vector.New(ap.Types[i])
   103  			}
   104  			proc.SetInputBatch(b)
   105  			if _, err := ctr.process(ap, proc, anal, isFirst, isLast); err != nil {
   106  				return false, err
   107  			}
   108  		}
   109  		if ctr.bat != nil {
   110  			ctr.bat.ExpandNulls()
   111  			anal.Alloc(int64(ctr.bat.Size()))
   112  			anal.Output(ctr.bat, isLast)
   113  			proc.SetInputBatch(ctr.bat)
   114  			ctr.bat = nil
   115  			return true, nil
   116  		}
   117  		proc.SetInputBatch(nil)
   118  		return true, nil
   119  	}
   120  	defer bat.Clean(proc.Mp())
   121  	if len(bat.Vecs) == 0 {
   122  		return false, nil
   123  	}
   124  	anal.Input(bat, isFirst)
   125  	proc.SetInputBatch(&batch.Batch{})
   126  	if len(ctr.aggVecs) == 0 {
   127  		ctr.aggVecs = make([]evalVector, len(ap.Aggs))
   128  	}
   129  
   130  	if err := ctr.evalAggVector(bat, ap.Aggs, proc, anal); err != nil {
   131  		return false, err
   132  	}
   133  	defer ctr.cleanAggVectors(proc.Mp())
   134  
   135  	if len(ctr.multiVecs) == 0 {
   136  		ctr.multiVecs = make([][]evalVector, len(ap.MultiAggs))
   137  		for i, agg := range ap.MultiAggs {
   138  			ctr.multiVecs[i] = make([]evalVector, len(agg.GroupExpr))
   139  		}
   140  	}
   141  	if err := ctr.evalMultiAggs(bat, ap.MultiAggs, proc, anal); err != nil {
   142  		return false, err
   143  	}
   144  	defer ctr.cleanMultiAggVecs(proc.Mp())
   145  
   146  	if ctr.bat == nil {
   147  		var err error
   148  
   149  		ctr.bat = batch.NewWithSize(0)
   150  		ctr.bat.Zs = proc.Mp().GetSels()
   151  		ctr.bat.Zs = append(ctr.bat.Zs, 0)
   152  		ctr.bat.Aggs = make([]agg.Agg[any], len(ap.Aggs)+len(ap.MultiAggs))
   153  		for i, ag := range ap.Aggs {
   154  			if ctr.bat.Aggs[i], err = agg.New(ag.Op, ag.Dist, ctr.aggVecs[i].vec.Typ); err != nil {
   155  				ctr.bat = nil
   156  				return false, err
   157  			}
   158  		}
   159  		for i, agg := range ap.MultiAggs {
   160  			if ctr.bat.Aggs[i+len(ap.Aggs)] = group_concat.NewGroupConcat(&agg, ctr.ToInputType(i)); err != nil {
   161  				return false, err
   162  			}
   163  		}
   164  		for _, ag := range ctr.bat.Aggs {
   165  			if err := ag.Grows(1, proc.Mp()); err != nil {
   166  				return false, err
   167  			}
   168  		}
   169  	}
   170  	if bat.Length() == 0 {
   171  		return false, nil
   172  	}
   173  	if err := ctr.processH0(bat, ap, proc); err != nil {
   174  		return false, err
   175  	}
   176  	return false, nil
   177  }
   178  
   179  func (ctr *container) processWithGroup(ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool) (bool, error) {
   180  	var err error
   181  
   182  	bat := proc.InputBatch()
   183  	if bat == nil {
   184  		if ctr.bat != nil {
   185  			if ap.NeedEval {
   186  				for i, ag := range ctr.bat.Aggs {
   187  					vec, err := ag.Eval(proc.Mp())
   188  					if err != nil {
   189  						return false, err
   190  					}
   191  					ctr.bat.Aggs[i] = nil
   192  					ctr.bat.Vecs = append(ctr.bat.Vecs, vec)
   193  					anal.Alloc(int64(vec.Size()))
   194  				}
   195  				ctr.bat.Aggs = nil
   196  				for i := range ctr.bat.Zs { // reset zs
   197  					ctr.bat.Zs[i] = 1
   198  				}
   199  			}
   200  			ctr.bat.ExpandNulls()
   201  			anal.Output(ctr.bat, isLast)
   202  			proc.SetInputBatch(ctr.bat)
   203  			ctr.bat = nil
   204  			return true, nil
   205  		}
   206  		proc.SetInputBatch(nil)
   207  		return true, nil
   208  	}
   209  	if bat.Length() == 0 {
   210  		return false, nil
   211  	}
   212  	defer bat.Clean(proc.Mp())
   213  	anal.Input(bat, isFirst)
   214  	proc.SetInputBatch(&batch.Batch{})
   215  	if len(ctr.aggVecs) == 0 {
   216  		ctr.aggVecs = make([]evalVector, len(ap.Aggs))
   217  	}
   218  
   219  	if err := ctr.evalAggVector(bat, ap.Aggs, proc, anal); err != nil {
   220  		return false, err
   221  	}
   222  	defer ctr.cleanAggVectors(proc.Mp())
   223  
   224  	if len(ctr.multiVecs) == 0 {
   225  		ctr.multiVecs = make([][]evalVector, len(ap.MultiAggs))
   226  		for i, agg := range ap.MultiAggs {
   227  			ctr.multiVecs[i] = make([]evalVector, len(agg.GroupExpr))
   228  		}
   229  	}
   230  	if err := ctr.evalMultiAggs(bat, ap.MultiAggs, proc, anal); err != nil {
   231  		return false, err
   232  	}
   233  	defer ctr.cleanMultiAggVecs(proc.Mp())
   234  	if len(ctr.groupVecs) == 0 {
   235  		ctr.vecs = make([]*vector.Vector, len(ap.Exprs))
   236  		ctr.groupVecs = make([]evalVector, len(ap.Exprs))
   237  	}
   238  
   239  	for i, expr := range ap.Exprs {
   240  		vec, err := colexec.EvalExpr(bat, proc, expr)
   241  		if err != nil || vec.ConstExpand(false, proc.Mp()) == nil {
   242  			ctr.cleanGroupVectors(proc.Mp())
   243  			return false, err
   244  		}
   245  		ctr.groupVecs[i].vec = vec
   246  		ctr.groupVecs[i].needFree = true
   247  		for j := range bat.Vecs {
   248  			if bat.Vecs[j] == vec {
   249  				ctr.groupVecs[i].needFree = false
   250  				break
   251  			}
   252  		}
   253  		if ctr.groupVecs[i].needFree && vec != nil {
   254  			anal.Alloc(int64(vec.Size()))
   255  		}
   256  		ctr.vecs[i] = vec
   257  	}
   258  
   259  	if len(ctr.groupVecs) == 1 && !ctr.groupVecs[0].needFree {
   260  		if ctr.groupVecs[0].vec.IsLowCardinality() {
   261  			ctr.idx = ctr.groupVecs[0].vec.Index().(*index.LowCardinalityIndex).Dup()
   262  		}
   263  	}
   264  
   265  	if ctr.bat == nil {
   266  		size := 0
   267  		ctr.bat = batch.NewWithSize(len(ap.Exprs))
   268  		ctr.bat.Zs = proc.Mp().GetSels()
   269  		for i := range ctr.groupVecs {
   270  			vec := ctr.groupVecs[i].vec
   271  			ctr.bat.Vecs[i] = vector.New(vec.Typ)
   272  			switch vec.Typ.TypeSize() {
   273  			case 1:
   274  				size += 1 + 1
   275  			case 2:
   276  				size += 2 + 1
   277  			case 4:
   278  				size += 4 + 1
   279  			case 8:
   280  				size += 8 + 1
   281  			case 16:
   282  				size += 16 + 1
   283  			default:
   284  				size = 128
   285  			}
   286  		}
   287  		ctr.bat.Aggs = make([]agg.Agg[any], len(ap.Aggs)+len(ap.MultiAggs))
   288  		for i, ag := range ap.Aggs {
   289  			if ctr.bat.Aggs[i], err = agg.New(ag.Op, ag.Dist, ctr.aggVecs[i].vec.Typ); err != nil {
   290  				return false, err
   291  			}
   292  		}
   293  		for i, agg := range ap.MultiAggs {
   294  			if ctr.bat.Aggs[i+len(ap.Aggs)] = group_concat.NewGroupConcat(&agg, ctr.ToInputType(i)); err != nil {
   295  				return false, err
   296  			}
   297  		}
   298  		switch {
   299  		case ctr.idx != nil:
   300  			ctr.typ = HIndex
   301  		case size <= 8:
   302  			ctr.typ = H8
   303  			if ctr.intHashMap, err = hashmap.NewIntHashMap(true, ap.Ibucket, ap.Nbucket, proc.Mp()); err != nil {
   304  				return false, err
   305  			}
   306  		default:
   307  			ctr.typ = HStr
   308  			if ctr.strHashMap, err = hashmap.NewStrMap(true, ap.Ibucket, ap.Nbucket, proc.Mp()); err != nil {
   309  				return false, err
   310  			}
   311  		}
   312  	}
   313  	switch ctr.typ {
   314  	case H8:
   315  		err = ctr.processH8(bat, proc)
   316  	case HStr:
   317  		err = ctr.processHStr(bat, proc)
   318  	default:
   319  		err = ctr.processHIndex(bat, proc)
   320  	}
   321  	if err != nil {
   322  		return false, err
   323  	}
   324  	return false, err
   325  }
   326  
   327  func (ctr *container) processH0(bat *batch.Batch, ap *Argument, proc *process.Process) error {
   328  	for _, z := range bat.Zs {
   329  		ctr.bat.Zs[0] += z
   330  	}
   331  	for i, ag := range ctr.bat.Aggs {
   332  		if i < len(ctr.aggVecs) {
   333  			err := ag.BulkFill(0, bat.Zs, []*vector.Vector{ctr.aggVecs[i].vec})
   334  			if err != nil {
   335  				return err
   336  			}
   337  		} else {
   338  			err := ag.BulkFill(0, bat.Zs, ctr.ToVecotrs(i-len(ctr.aggVecs)))
   339  			if err != nil {
   340  				return err
   341  			}
   342  		}
   343  	}
   344  	return nil
   345  }
   346  
   347  func (ctr *container) processH8(bat *batch.Batch, proc *process.Process) error {
   348  	count := bat.Length()
   349  	itr := ctr.intHashMap.NewIterator()
   350  	for i := 0; i < count; i += hashmap.UnitLimit {
   351  		n := count - i
   352  		if n > hashmap.UnitLimit {
   353  			n = hashmap.UnitLimit
   354  		}
   355  		rows := ctr.intHashMap.GroupCount()
   356  		vals, _, err := itr.Insert(i, n, ctr.vecs)
   357  		if err != nil {
   358  			return err
   359  		}
   360  		if err := ctr.batchFill(i, n, bat, vals, rows, proc); err != nil {
   361  			return err
   362  		}
   363  	}
   364  	return nil
   365  }
   366  
   367  func (ctr *container) processHStr(bat *batch.Batch, proc *process.Process) error {
   368  	count := bat.Length()
   369  	itr := ctr.strHashMap.NewIterator()
   370  	for i := 0; i < count; i += hashmap.UnitLimit { // batch
   371  		n := count - i
   372  		if n > hashmap.UnitLimit {
   373  			n = hashmap.UnitLimit
   374  		}
   375  		rows := ctr.strHashMap.GroupCount()
   376  		vals, _, err := itr.Insert(i, n, ctr.vecs)
   377  		if err != nil {
   378  			return err
   379  		}
   380  		if err := ctr.batchFill(i, n, bat, vals, rows, proc); err != nil {
   381  			return err
   382  		}
   383  	}
   384  	return nil
   385  }
   386  
   387  func (ctr *container) processHIndex(bat *batch.Batch, proc *process.Process) error {
   388  	mSels := make([][]int64, index.MaxLowCardinality+1)
   389  	poses := vector.MustTCols[uint16](ctr.idx.GetPoses())
   390  	for k, v := range poses {
   391  		if len(mSels[v]) == 0 {
   392  			mSels[v] = make([]int64, 0, 64)
   393  		}
   394  		mSels[v] = append(mSels[v], int64(k))
   395  	}
   396  	if len(mSels[0]) == 0 { // hasNotNull == true
   397  		mSels = mSels[1:]
   398  	}
   399  
   400  	var groups []int64
   401  	for i, sels := range mSels {
   402  		if len(sels) > 0 {
   403  			groups = append(groups, sels[0])
   404  			ctr.bat.Zs = append(ctr.bat.Zs, 0)
   405  			for _, k := range sels {
   406  				ctr.bat.Zs[i] += bat.Zs[k]
   407  			}
   408  		}
   409  	}
   410  
   411  	for _, ag := range ctr.bat.Aggs {
   412  		if err := ag.Grows(len(groups), proc.Mp()); err != nil {
   413  			return err
   414  		}
   415  	}
   416  	if err := vector.Union(ctr.bat.Vecs[0], ctr.vecs[0], groups, false, proc.Mp()); err != nil {
   417  		return err
   418  	}
   419  	for i, ag := range ctr.bat.Aggs {
   420  
   421  		for j, sels := range mSels {
   422  			for _, sel := range sels {
   423  				if i < len(ctr.aggVecs) {
   424  					aggVecs := []*vector.Vector{ctr.aggVecs[i].vec}
   425  					if err := ag.Fill(int64(j), sel, 1, aggVecs); err != nil {
   426  						return err
   427  					}
   428  				} else {
   429  					if err := ag.Fill(int64(j), sel, 1, ctr.ToVecotrs(i-len(ctr.aggVecs))); err != nil {
   430  						return err
   431  					}
   432  				}
   433  			}
   434  		}
   435  	}
   436  	return nil
   437  }
   438  
   439  func (ctr *container) batchFill(i int, n int, bat *batch.Batch, vals []uint64, hashRows uint64, proc *process.Process) error {
   440  	cnt := 0
   441  	valCnt := 0
   442  	copy(ctr.inserted[:n], ctr.zInserted[:n])
   443  	for k, v := range vals[:n] {
   444  		if v == 0 {
   445  			continue
   446  		}
   447  		if v > hashRows {
   448  			ctr.inserted[k] = 1
   449  			hashRows++
   450  			cnt++
   451  			ctr.bat.Zs = append(ctr.bat.Zs, 0)
   452  		}
   453  		valCnt++
   454  		ai := int64(v) - 1
   455  		ctr.bat.Zs[ai] += bat.Zs[i+k]
   456  	}
   457  	if cnt > 0 {
   458  		for j, vec := range ctr.bat.Vecs {
   459  			if err := vector.UnionBatch(vec, ctr.groupVecs[j].vec, int64(i), cnt, ctr.inserted[:n], proc.Mp()); err != nil {
   460  				return err
   461  			}
   462  		}
   463  		for _, ag := range ctr.bat.Aggs {
   464  			if err := ag.Grows(cnt, proc.Mp()); err != nil {
   465  				return err
   466  			}
   467  		}
   468  	}
   469  	if valCnt == 0 {
   470  		return nil
   471  	}
   472  	for j, ag := range ctr.bat.Aggs {
   473  		if j < len(ctr.aggVecs) {
   474  			err := ag.BatchFill(int64(i), ctr.inserted[:n], vals, bat.Zs, []*vector.Vector{ctr.aggVecs[j].vec})
   475  			if err != nil {
   476  				return err
   477  			}
   478  		} else {
   479  			err := ag.BatchFill(int64(i), ctr.inserted[:n], vals, bat.Zs, ctr.ToVecotrs(j-len(ctr.aggVecs)))
   480  			if err != nil {
   481  				return err
   482  			}
   483  		}
   484  	}
   485  	return nil
   486  }
   487  
   488  func (ctr *container) evalAggVector(bat *batch.Batch, aggs []agg.Aggregate, proc *process.Process, analyze process.Analyze) error {
   489  	for i, ag := range aggs {
   490  		vec, err := colexec.EvalExpr(bat, proc, ag.E)
   491  		if err != nil || vec.ConstExpand(false, proc.Mp()) == nil {
   492  			ctr.cleanAggVectors(proc.Mp())
   493  			return err
   494  		}
   495  		ctr.aggVecs[i].vec = vec
   496  		ctr.aggVecs[i].needFree = true
   497  		for j := range bat.Vecs {
   498  			if bat.Vecs[j] == vec {
   499  				ctr.aggVecs[i].needFree = false
   500  				break
   501  			}
   502  		}
   503  		if ctr.aggVecs[i].needFree && vec != nil {
   504  			analyze.Alloc(int64(vec.Size()))
   505  		}
   506  	}
   507  	return nil
   508  }
   509  
   510  func (ctr *container) evalMultiAggs(bat *batch.Batch, multiAggs []group_concat.Argument, proc *process.Process, analyze process.Analyze) error {
   511  	for i := range multiAggs {
   512  		for j, expr := range multiAggs[i].GroupExpr {
   513  			vec, err := colexec.EvalExpr(bat, proc, expr)
   514  			if err != nil || vec.ConstExpand(false, proc.Mp()) == nil {
   515  				ctr.cleanMultiAggVecs(proc.Mp())
   516  				return err
   517  			}
   518  			ctr.multiVecs[i][j].vec = vec
   519  			ctr.multiVecs[i][j].needFree = true
   520  			for k := range bat.Vecs {
   521  				if bat.Vecs[k] == vec {
   522  					ctr.multiVecs[i][j].needFree = false
   523  					break
   524  				}
   525  			}
   526  			if ctr.multiVecs[i][j].needFree && vec != nil {
   527  				analyze.Alloc(int64(vec.Size()))
   528  			}
   529  		}
   530  	}
   531  	return nil
   532  }