github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/causetstore/causetstore/mockeinsteindb/aggregate.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package mockeinsteindb
    15  
    16  import (
    17  	"context"
    18  	"time"
    19  
    20  	"github.com/whtcorpsinc/errors"
    21  	"github.com/whtcorpsinc/milevadb/memex"
    22  	"github.com/whtcorpsinc/milevadb/memex/aggregation"
    23  	"github.com/whtcorpsinc/milevadb/types"
    24  	"github.com/whtcorpsinc/milevadb/soliton/chunk"
    25  	"github.com/whtcorpsinc/milevadb/soliton/codec"
    26  )
    27  
// aggCtxsMapper maps an encoded group key to the aggregate evaluation
// contexts for that group, one context per aggregate expression.
type aggCtxsMapper map[string][]*aggregation.AggEvaluateContext
    29  
// Compile-time assertions that both aggregation executors satisfy the
// interlock interface.
var (
	_ interlock = &hashAggInterDirc{}
	_ interlock = &streamAggInterDirc{}
)
    34  
// hashAggInterDirc implements hash aggregation: it drains its source
// executor completely, grouping rows by the encoded group-by key, and then
// emits one result row per group.
type hashAggInterDirc struct {
	evalCtx           *evalContext              // shared evaluation context (statement context, column info)
	aggExprs          []aggregation.Aggregation // aggregate functions to evaluate
	aggCtxsMap        aggCtxsMapper             // lazily-built per-group evaluation contexts
	groupByExprs      []memex.Expression        // GROUP BY expressions
	relatedDefCausOffsets []int                 // offsets of columns the expressions actually read
	event               []types.Causet          // decoded values of the current input row
	groups            map[string]struct{}       // set of group keys seen so far
	groupKeys         [][]byte                  // group keys in first-seen order
	groupKeyRows      [][][]byte                // per-key encoded group-by column values
	executed          bool                      // true once the source has been fully drained
	currGroupIdx      int                       // index of the next group to emit
	count             int64                     // number of Next calls
	execDetail        *execDetail               // execution statistics for this executor

	src interlock // upstream executor feeding rows
}
    52  
    53  func (e *hashAggInterDirc) InterDircDetails() []*execDetail {
    54  	var suffix []*execDetail
    55  	if e.src != nil {
    56  		suffix = e.src.InterDircDetails()
    57  	}
    58  	return append(suffix, e.execDetail)
    59  }
    60  
    61  func (e *hashAggInterDirc) SetSrcInterDirc(exec interlock) {
    62  	e.src = exec
    63  }
    64  
    65  func (e *hashAggInterDirc) GetSrcInterDirc() interlock {
    66  	return e.src
    67  }
    68  
    69  func (e *hashAggInterDirc) ResetCounts() {
    70  	e.src.ResetCounts()
    71  }
    72  
    73  func (e *hashAggInterDirc) Counts() []int64 {
    74  	return e.src.Counts()
    75  }
    76  
    77  func (e *hashAggInterDirc) innerNext(ctx context.Context) (bool, error) {
    78  	values, err := e.src.Next(ctx)
    79  	if err != nil {
    80  		return false, errors.Trace(err)
    81  	}
    82  	if values == nil {
    83  		return false, nil
    84  	}
    85  	err = e.aggregate(values)
    86  	if err != nil {
    87  		return false, errors.Trace(err)
    88  	}
    89  	return true, nil
    90  }
    91  
    92  func (e *hashAggInterDirc) Cursor() ([]byte, bool) {
    93  	panic("don't not use interlock streaming API for hash aggregation!")
    94  }
    95  
// Next implements the interlock interface. On the first call it drains the
// source executor completely, aggregating every row into per-group contexts;
// afterwards each call emits one row per group in first-seen order: the
// encoded partial results of every aggregate expression followed by the
// encoded group-by column values. It returns nil once all groups are emitted.
func (e *hashAggInterDirc) Next(ctx context.Context) (value [][]byte, err error) {
	defer func(begin time.Time) {
		e.execDetail.uFIDelate(begin, value)
	}(time.Now())
	e.count++
	if e.aggCtxsMap == nil {
		e.aggCtxsMap = make(aggCtxsMapper)
	}
	if !e.executed {
		// Hash aggregation must consume the entire input before producing
		// any output row.
		for {
			hasMore, err := e.innerNext(ctx)
			if err != nil {
				return nil, errors.Trace(err)
			}
			if !hasMore {
				break
			}
		}
		e.executed = true
	}

	if e.currGroupIdx >= len(e.groups) {
		// All groups have been emitted.
		return nil, nil
	}
	gk := e.groupKeys[e.currGroupIdx]
	value = make([][]byte, 0, len(e.groupByExprs)+2*len(e.aggExprs))
	aggCtxs := e.getContexts(gk)
	for i, agg := range e.aggExprs {
		partialResults := agg.GetPartialResult(aggCtxs[i])
		for _, result := range partialResults {
			data, err := codec.EncodeValue(e.evalCtx.sc, nil, result)
			if err != nil {
				return nil, errors.Trace(err)
			}
			value = append(value, data)
		}
	}
	// Append the encoded group-by values after the aggregate results.
	value = append(value, e.groupKeyRows[e.currGroupIdx]...)
	e.currGroupIdx++

	return value, nil
}
   138  
   139  func (e *hashAggInterDirc) getGroupKey() ([]byte, [][]byte, error) {
   140  	length := len(e.groupByExprs)
   141  	if length == 0 {
   142  		return nil, nil, nil
   143  	}
   144  	bufLen := 0
   145  	event := make([][]byte, 0, length)
   146  	for _, item := range e.groupByExprs {
   147  		v, err := item.Eval(chunk.MutRowFromCausets(e.event).ToRow())
   148  		if err != nil {
   149  			return nil, nil, errors.Trace(err)
   150  		}
   151  		b, err := codec.EncodeValue(e.evalCtx.sc, nil, v)
   152  		if err != nil {
   153  			return nil, nil, errors.Trace(err)
   154  		}
   155  		bufLen += len(b)
   156  		event = append(event, b)
   157  	}
   158  	buf := make([]byte, 0, bufLen)
   159  	for _, col := range event {
   160  		buf = append(buf, col...)
   161  	}
   162  	return buf, event, nil
   163  }
   164  
// aggregate uFIDelates aggregate functions with event.
// It decodes the needed columns of the raw row into e.event, registers the
// row's group key if it has not been seen before, and feeds the row into
// every aggregate expression's evaluation context for that group.
func (e *hashAggInterDirc) aggregate(value [][]byte) error {
	err := e.evalCtx.decodeRelatedDeferredCausetVals(e.relatedDefCausOffsets, value, e.event)
	if err != nil {
		return errors.Trace(err)
	}
	// Get group key.
	gk, gbyKeyRow, err := e.getGroupKey()
	if err != nil {
		return errors.Trace(err)
	}
	// First occurrence of this group key: remember it so Next can iterate
	// the groups in first-seen order.
	if _, ok := e.groups[string(gk)]; !ok {
		e.groups[string(gk)] = struct{}{}
		e.groupKeys = append(e.groupKeys, gk)
		e.groupKeyRows = append(e.groupKeyRows, gbyKeyRow)
	}
	// UFIDelate aggregate memexs.
	aggCtxs := e.getContexts(gk)
	for i, agg := range e.aggExprs {
		err = agg.UFIDelate(aggCtxs[i], e.evalCtx.sc, chunk.MutRowFromCausets(e.event).ToRow())
		if err != nil {
			return errors.Trace(err)
		}
	}
	return nil
}
   191  
   192  func (e *hashAggInterDirc) getContexts(groupKey []byte) []*aggregation.AggEvaluateContext {
   193  	groupKeyString := string(groupKey)
   194  	aggCtxs, ok := e.aggCtxsMap[groupKeyString]
   195  	if !ok {
   196  		aggCtxs = make([]*aggregation.AggEvaluateContext, 0, len(e.aggExprs))
   197  		for _, agg := range e.aggExprs {
   198  			aggCtxs = append(aggCtxs, agg.CreateContext(e.evalCtx.sc))
   199  		}
   200  		e.aggCtxsMap[groupKeyString] = aggCtxs
   201  	}
   202  	return aggCtxs
   203  }
   204  
// streamAggInterDirc implements stream aggregation. It relies on its input
// arriving sorted by the group-by columns, so each group's rows are
// contiguous and a result row can be emitted as soon as the group changes.
type streamAggInterDirc struct {
	evalCtx           *evalContext              // shared evaluation context (statement context, column info)
	aggExprs          []aggregation.Aggregation // aggregate functions to evaluate
	aggCtxs           []*aggregation.AggEvaluateContext // one context per aggregate, reset at group boundaries
	groupByExprs      []memex.Expression        // GROUP BY expressions
	relatedDefCausOffsets []int                 // offsets of columns the expressions actually read
	event               []types.Causet          // decoded values of the current input row
	tmpGroupByRow     []types.Causet            // scratch row reused by meetNewGroup
	currGroupByRow    []types.Causet            // group-by values of the group being emitted
	nextGroupByRow    []types.Causet            // group-by values of the row that opened the next group
	currGroupByValues [][]byte                  // encoded currGroupByRow, rebuilt per result row
	executed          bool                      // true once the source has been exhausted
	hasData           bool                      // true once at least one input row was seen
	count             int64                     // number of Next calls
	execDetail        *execDetail               // execution statistics for this executor

	src interlock // upstream executor feeding sorted rows
}
   223  
   224  func (e *streamAggInterDirc) InterDircDetails() []*execDetail {
   225  	var suffix []*execDetail
   226  	if e.src != nil {
   227  		suffix = e.src.InterDircDetails()
   228  	}
   229  	return append(suffix, e.execDetail)
   230  }
   231  
   232  func (e *streamAggInterDirc) SetSrcInterDirc(exec interlock) {
   233  	e.src = exec
   234  }
   235  
   236  func (e *streamAggInterDirc) GetSrcInterDirc() interlock {
   237  	return e.src
   238  }
   239  
   240  func (e *streamAggInterDirc) ResetCounts() {
   241  	e.src.ResetCounts()
   242  }
   243  
   244  func (e *streamAggInterDirc) Counts() []int64 {
   245  	return e.src.Counts()
   246  }
   247  
// getPartialResult builds the output row for the group that just finished:
// the encoded partial results of every aggregate expression followed by the
// encoded group-by values. As side effects it resets each aggregate context
// for the upcoming group and rotates currGroupByRow to the pending next
// group's row, so the statement order here is load-bearing.
func (e *streamAggInterDirc) getPartialResult() ([][]byte, error) {
	value := make([][]byte, 0, len(e.groupByExprs)+2*len(e.aggExprs))
	for i, agg := range e.aggExprs {
		partialResults := agg.GetPartialResult(e.aggCtxs[i])
		for _, result := range partialResults {
			data, err := codec.EncodeValue(e.evalCtx.sc, nil, result)
			if err != nil {
				return nil, errors.Trace(err)
			}
			value = append(value, data)
		}
		// Clear the aggregate context.
		e.aggCtxs[i] = agg.CreateContext(e.evalCtx.sc)
	}
	// Encode the finished group's group-by values.
	e.currGroupByValues = e.currGroupByValues[:0]
	for _, d := range e.currGroupByRow {
		buf, err := codec.EncodeValue(e.evalCtx.sc, nil, d)
		if err != nil {
			return nil, errors.Trace(err)
		}
		e.currGroupByValues = append(e.currGroupByValues, buf)
	}
	// The pending next group (if any) becomes the current one.
	e.currGroupByRow = types.CloneRow(e.nextGroupByRow)
	return append(value, e.currGroupByValues...), nil
}
   273  
// meetNewGroup evaluates the group-by expressions against the decoded
// current row (e.event) and reports whether that row starts a new group.
// It returns false when there are no group-by expressions, for the very
// first row, and for rows whose group-by values equal the previous row's.
// Note the `event` parameter is not read; the row comes from e.event.
func (e *streamAggInterDirc) meetNewGroup(event [][]byte) (bool, error) {
	if len(e.groupByExprs) == 0 {
		return false, nil
	}

	e.tmpGroupByRow = e.tmpGroupByRow[:0]
	matched, firstGroup := true, false
	if e.nextGroupByRow == nil {
		// No previous row seen yet: this row opens the first group.
		matched, firstGroup = false, true
	}
	for i, item := range e.groupByExprs {
		d, err := item.Eval(chunk.MutRowFromCausets(e.event).ToRow())
		if err != nil {
			return false, errors.Trace(err)
		}
		if matched {
			// Still matching so far; compare this column with the previous
			// group's value.
			c, err := d.CompareCauset(e.evalCtx.sc, &e.nextGroupByRow[i])
			if err != nil {
				return false, errors.Trace(err)
			}
			matched = c == 0
		}
		e.tmpGroupByRow = append(e.tmpGroupByRow, d)
	}
	if firstGroup {
		e.currGroupByRow = types.CloneRow(e.tmpGroupByRow)
	}
	if matched {
		// Same group as the previous row.
		return false, nil
	}
	e.nextGroupByRow = e.tmpGroupByRow
	// A new group started, unless this was the very first row.
	return !firstGroup, nil
}
   307  
   308  func (e *streamAggInterDirc) Cursor() ([]byte, bool) {
   309  	panic("don't not use interlock streaming API for stream aggregation!")
   310  }
   311  
// Next implements the interlock interface for stream aggregation. It reads
// rows from the (group-sorted) source until the group-by values change, then
// returns the finished group's row: encoded aggregate partial results
// followed by the encoded group-by values. When the source is exhausted it
// flushes the final group — or, with no GROUP BY, the single global group
// even if no rows arrived — and returns nil on every later call.
func (e *streamAggInterDirc) Next(ctx context.Context) (retRow [][]byte, err error) {
	defer func(begin time.Time) {
		e.execDetail.uFIDelate(begin, retRow)
	}(time.Now())
	e.count++
	if e.executed {
		return nil, nil
	}

	for {
		values, err := e.src.Next(ctx)
		if err != nil {
			return nil, errors.Trace(err)
		}
		if values == nil {
			// Source exhausted: emit the last pending group. With GROUP BY
			// and no input at all, there is nothing to emit.
			e.executed = true
			if !e.hasData && len(e.groupByExprs) > 0 {
				return nil, nil
			}
			return e.getPartialResult()
		}

		e.hasData = true
		err = e.evalCtx.decodeRelatedDeferredCausetVals(e.relatedDefCausOffsets, values, e.event)
		if err != nil {
			return nil, errors.Trace(err)
		}
		newGroup, err := e.meetNewGroup(values)
		if err != nil {
			return nil, errors.Trace(err)
		}
		if newGroup {
			// The previous group is complete: snapshot its result (this also
			// resets the aggregate contexts) before folding in this row.
			retRow, err = e.getPartialResult()
			if err != nil {
				return nil, errors.Trace(err)
			}
		}
		for i, agg := range e.aggExprs {
			err = agg.UFIDelate(e.aggCtxs[i], e.evalCtx.sc, chunk.MutRowFromCausets(e.event).ToRow())
			if err != nil {
				return nil, errors.Trace(err)
			}
		}
		if newGroup {
			return retRow, nil
		}
	}
}