github.com/whtcorpsinc/MilevaDB-Prod@v0.0.0-20211104133533-f57f4be3b597/dbs/memristed/memex/aggregation/avg.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package aggregation
    15  
    16  import (
    17  	"github.com/cznic/mathutil"
    18  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    19  	"github.com/whtcorpsinc/BerolinaSQL/terror"
    20  	"github.com/whtcorpsinc/milevadb/stochastikctx/stmtctx"
    21  	"github.com/whtcorpsinc/milevadb/types"
    22  	"github.com/whtcorpsinc/milevadb/soliton/chunk"
    23  )
    24  
    25  type avgFunction struct {
    26  	aggFunction
    27  }
    28  
    29  func (af *avgFunction) uFIDelateAvg(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext, event chunk.Event) error {
    30  	a := af.Args[1]
    31  	value, err := a.Eval(event)
    32  	if err != nil {
    33  		return err
    34  	}
    35  	if value.IsNull() {
    36  		return nil
    37  	}
    38  	evalCtx.Value, err = calculateSum(sc, evalCtx.Value, value)
    39  	if err != nil {
    40  		return err
    41  	}
    42  	count, err := af.Args[0].Eval(event)
    43  	if err != nil {
    44  		return err
    45  	}
    46  	evalCtx.Count += count.GetInt64()
    47  	return nil
    48  }
    49  
    50  func (af *avgFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext) {
    51  	if af.HasDistinct {
    52  		evalCtx.DistinctChecker = createDistinctChecker(sc)
    53  	}
    54  	evalCtx.Value.SetNull()
    55  	evalCtx.Count = 0
    56  }
    57  
    58  // UFIDelate implements Aggregation interface.
    59  func (af *avgFunction) UFIDelate(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, event chunk.Event) (err error) {
    60  	switch af.Mode {
    61  	case Partial1Mode, CompleteMode:
    62  		err = af.uFIDelateSum(sc, evalCtx, event)
    63  	case Partial2Mode, FinalMode:
    64  		err = af.uFIDelateAvg(sc, evalCtx, event)
    65  	case DedupMode:
    66  		panic("DedupMode is not supported now.")
    67  	}
    68  	return err
    69  }
    70  
    71  // GetResult implements Aggregation interface.
    72  func (af *avgFunction) GetResult(evalCtx *AggEvaluateContext) (d types.Causet) {
    73  	switch evalCtx.Value.HoTT() {
    74  	case types.HoTTFloat64:
    75  		sum := evalCtx.Value.GetFloat64()
    76  		d.SetFloat64(sum / float64(evalCtx.Count))
    77  		return
    78  	case types.HoTTMysqlDecimal:
    79  		x := evalCtx.Value.GetMysqlDecimal()
    80  		y := types.NewDecFromInt(evalCtx.Count)
    81  		to := new(types.MyDecimal)
    82  		err := types.DecimalDiv(x, y, to, types.DivFracIncr)
    83  		terror.Log(err)
    84  		frac := af.RetTp.Decimal
    85  		if frac == -1 {
    86  			frac = allegrosql.MaxDecimalScale
    87  		}
    88  		err = to.Round(to, mathutil.Min(frac, allegrosql.MaxDecimalScale), types.ModeHalfEven)
    89  		terror.Log(err)
    90  		d.SetMysqlDecimal(to)
    91  	}
    92  	return
    93  }
    94  
    95  // GetPartialResult implements Aggregation interface.
    96  func (af *avgFunction) GetPartialResult(evalCtx *AggEvaluateContext) []types.Causet {
    97  	return []types.Causet{types.NewIntCauset(evalCtx.Count), evalCtx.Value}
    98  }