github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/colexec/mergegroup/group.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mergegroup
    16  
    17  import (
    18  	"bytes"
    19  	"reflect"
    20  	"time"
    21  
    22  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    23  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    24  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    25  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    26  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    27  )
    28  
    29  func String(_ interface{}, buf *bytes.Buffer) {
    30  	buf.WriteString("mergeroup()")
    31  }
    32  
    33  func Prepare(proc *process.Process, arg interface{}) error {
    34  	ap := arg.(*Argument)
    35  	ap.ctr = new(container)
    36  	ap.ctr.inserted = make([]uint8, hashmap.UnitLimit)
    37  	ap.ctr.zInserted = make([]uint8, hashmap.UnitLimit)
    38  
    39  	ap.ctr.receiverListener = make([]reflect.SelectCase, len(proc.Reg.MergeReceivers))
    40  	for i, mr := range proc.Reg.MergeReceivers {
    41  		ap.ctr.receiverListener[i] = reflect.SelectCase{
    42  			Dir:  reflect.SelectRecv,
    43  			Chan: reflect.ValueOf(mr.Ch),
    44  		}
    45  	}
    46  	ap.ctr.aliveMergeReceiver = len(proc.Reg.MergeReceivers)
    47  	return nil
    48  }
    49  
    50  func Call(idx int, proc *process.Process, arg interface{}, isFirst bool, isLast bool) (bool, error) {
    51  	ap := arg.(*Argument)
    52  	ctr := ap.ctr
    53  	anal := proc.GetAnalyze(idx)
    54  	anal.Start()
    55  	defer anal.Stop()
    56  
    57  	for {
    58  		switch ctr.state {
    59  		case Build:
    60  			if err := ctr.build(proc, anal, isFirst); err != nil {
    61  				return false, err
    62  			}
    63  			ctr.state = Eval
    64  		case Eval:
    65  			if ctr.bat != nil {
    66  				if ap.NeedEval {
    67  					for i, agg := range ctr.bat.Aggs {
    68  						vec, err := agg.Eval(proc.Mp())
    69  						if err != nil {
    70  							ctr.state = End
    71  							return false, err
    72  						}
    73  						ctr.bat.Aggs[i] = nil
    74  						ctr.bat.Vecs = append(ctr.bat.Vecs, vec)
    75  						if vec != nil {
    76  							anal.Alloc(int64(vec.Size()))
    77  						}
    78  					}
    79  					ctr.bat.Aggs = nil
    80  					for i := range ctr.bat.Zs { // reset zs
    81  						ctr.bat.Zs[i] = 1
    82  					}
    83  				}
    84  				anal.Output(ctr.bat, isLast)
    85  				ctr.bat.ExpandNulls()
    86  			}
    87  			ctr.state = End
    88  		case End:
    89  			proc.SetInputBatch(ctr.bat)
    90  			ctr.bat = nil
    91  			ap.Free(proc, false)
    92  			return true, nil
    93  		}
    94  	}
    95  }
    96  
    97  func (ctr *container) build(proc *process.Process, anal process.Analyze, isFirst bool) error {
    98  	var err error
    99  	for {
   100  		if ctr.aliveMergeReceiver == 0 {
   101  			return nil
   102  		}
   103  
   104  		start := time.Now()
   105  		chosen, value, ok := reflect.Select(ctr.receiverListener)
   106  		if !ok {
   107  			return moerr.NewInternalError(proc.Ctx, "pipeline closed unexpectedly")
   108  		}
   109  		anal.WaitStop(start)
   110  
   111  		pointer := value.UnsafePointer()
   112  		bat := (*batch.Batch)(pointer)
   113  		if bat == nil {
   114  			ctr.receiverListener = append(ctr.receiverListener[:chosen], ctr.receiverListener[chosen+1:]...)
   115  			ctr.aliveMergeReceiver--
   116  			continue
   117  		}
   118  
   119  		if bat.Length() == 0 {
   120  			continue
   121  		}
   122  
   123  		anal.Input(bat, isFirst)
   124  		if err = ctr.process(bat, proc); err != nil {
   125  			bat.Clean(proc.Mp())
   126  			return err
   127  		}
   128  	}
   129  }
   130  
   131  func (ctr *container) process(bat *batch.Batch, proc *process.Process) error {
   132  	var err error
   133  
   134  	if ctr.bat == nil {
   135  		size := 0
   136  		for _, vec := range bat.Vecs {
   137  			switch vec.Typ.TypeSize() {
   138  			case 1:
   139  				size += 1 + 1
   140  			case 2:
   141  				size += 2 + 1
   142  			case 4:
   143  				size += 4 + 1
   144  			case 8:
   145  				size += 8 + 1
   146  			case 16:
   147  				size += 16 + 1
   148  			default:
   149  				size = 128
   150  			}
   151  		}
   152  		switch {
   153  		case size == 0:
   154  			ctr.typ = H0
   155  		case size <= 8:
   156  			ctr.typ = H8
   157  			if ctr.intHashMap, err = hashmap.NewIntHashMap(true, 0, 0, proc.Mp()); err != nil {
   158  				return err
   159  			}
   160  		default:
   161  			ctr.typ = HStr
   162  			if ctr.strHashMap, err = hashmap.NewStrMap(true, 0, 0, proc.Mp()); err != nil {
   163  				return err
   164  			}
   165  		}
   166  	}
   167  	switch ctr.typ {
   168  	case H0:
   169  		err = ctr.processH0(bat, proc)
   170  	case H8:
   171  		err = ctr.processH8(bat, proc)
   172  	default:
   173  		err = ctr.processHStr(bat, proc)
   174  	}
   175  	if err != nil {
   176  		return err
   177  	}
   178  	return nil
   179  }
   180  
   181  func (ctr *container) processH0(bat *batch.Batch, proc *process.Process) error {
   182  	if ctr.bat == nil {
   183  		ctr.bat = bat
   184  		return nil
   185  	}
   186  	defer bat.Clean(proc.Mp())
   187  	for _, z := range bat.Zs {
   188  		ctr.bat.Zs[0] += z
   189  	}
   190  	for i, agg := range ctr.bat.Aggs {
   191  		err := agg.Merge(bat.Aggs[i], 0, 0)
   192  		if err != nil {
   193  			return err
   194  		}
   195  	}
   196  	return nil
   197  }
   198  
   199  func (ctr *container) processH8(bat *batch.Batch, proc *process.Process) error {
   200  	count := bat.Length()
   201  	itr := ctr.intHashMap.NewIterator()
   202  	flg := ctr.bat == nil
   203  	if !flg {
   204  		defer bat.Clean(proc.Mp())
   205  	}
   206  	for i := 0; i < count; i += hashmap.UnitLimit {
   207  		n := count - i
   208  		if n > hashmap.UnitLimit {
   209  			n = hashmap.UnitLimit
   210  		}
   211  		rowCount := ctr.intHashMap.GroupCount()
   212  		vals, _, err := itr.Insert(i, n, bat.Vecs)
   213  		if err != nil {
   214  			return err
   215  		}
   216  		if !flg {
   217  			if err = ctr.batchFill(i, n, bat, vals, rowCount, proc); err != nil {
   218  				return err
   219  			}
   220  		}
   221  	}
   222  	if flg {
   223  		ctr.bat = bat
   224  	}
   225  	return nil
   226  }
   227  
   228  func (ctr *container) processHStr(bat *batch.Batch, proc *process.Process) error {
   229  	count := bat.Length()
   230  	itr := ctr.strHashMap.NewIterator()
   231  	flg := ctr.bat == nil
   232  	if !flg {
   233  		defer bat.Clean(proc.Mp())
   234  	}
   235  	for i := 0; i < count; i += hashmap.UnitLimit { // batch
   236  		n := count - i
   237  		if n > hashmap.UnitLimit {
   238  			n = hashmap.UnitLimit
   239  		}
   240  		rowCount := ctr.strHashMap.GroupCount()
   241  		vals, _, err := itr.Insert(i, n, bat.Vecs)
   242  		if err != nil {
   243  			return err
   244  		}
   245  		if !flg {
   246  			if err := ctr.batchFill(i, n, bat, vals, rowCount, proc); err != nil {
   247  				return err
   248  			}
   249  		}
   250  	}
   251  	if flg {
   252  		ctr.bat = bat
   253  	}
   254  	return nil
   255  }
   256  
   257  func (ctr *container) batchFill(i int, n int, bat *batch.Batch, vals []uint64, hashRows uint64, proc *process.Process) error {
   258  	cnt := 0
   259  	copy(ctr.inserted[:n], ctr.zInserted[:n])
   260  	for k, v := range vals {
   261  		if v > hashRows {
   262  			ctr.inserted[k] = 1
   263  			hashRows++
   264  			cnt++
   265  			ctr.bat.Zs = append(ctr.bat.Zs, 0)
   266  		}
   267  		ai := int64(v) - 1
   268  		ctr.bat.Zs[ai] += bat.Zs[i+k]
   269  	}
   270  	if cnt > 0 {
   271  		for j, vec := range ctr.bat.Vecs {
   272  			if err := vector.UnionBatch(vec, bat.Vecs[j], int64(i), cnt, ctr.inserted[:n], proc.Mp()); err != nil {
   273  				return err
   274  			}
   275  		}
   276  		for _, agg := range ctr.bat.Aggs {
   277  			if err := agg.Grows(cnt, proc.Mp()); err != nil {
   278  				return err
   279  			}
   280  		}
   281  	}
   282  	for j, agg := range ctr.bat.Aggs {
   283  		if err := agg.BatchMerge(bat.Aggs[j], int64(i), ctr.inserted[:n], vals); err != nil {
   284  			return err
   285  		}
   286  	}
   287  	return nil
   288  }