github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/dbs/memristed/memex/aggregation/base_func.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package aggregation
    15  
    16  import (
    17  	"bytes"
    18  	"math"
    19  	"strings"
    20  
    21  	"github.com/cznic/mathutil"
    22  	"github.com/whtcorpsinc/errors"
    23  	"github.com/whtcorpsinc/BerolinaSQL/ast"
    24  	"github.com/whtcorpsinc/BerolinaSQL/charset"
    25  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    26  	"github.com/whtcorpsinc/milevadb/memex"
    27  	"github.com/whtcorpsinc/milevadb/stochastikctx"
    28  	"github.com/whtcorpsinc/milevadb/types"
    29  )
    30  
    31  // baseFuncDesc describes an function signature, only used in causet.
    32  type baseFuncDesc struct {
    33  	// Name represents the function name.
    34  	Name string
    35  	// Args represents the arguments of the function.
    36  	Args []memex.Expression
    37  	// RetTp represents the return type of the function.
    38  	RetTp *types.FieldType
    39  }
    40  
    41  func newBaseFuncDesc(ctx stochastikctx.Context, name string, args []memex.Expression) (baseFuncDesc, error) {
    42  	b := baseFuncDesc{Name: strings.ToLower(name), Args: args}
    43  	err := b.typeInfer(ctx)
    44  	return b, err
    45  }
    46  
    47  func (a *baseFuncDesc) equal(ctx stochastikctx.Context, other *baseFuncDesc) bool {
    48  	if a.Name != other.Name || len(a.Args) != len(other.Args) {
    49  		return false
    50  	}
    51  	for i := range a.Args {
    52  		if !a.Args[i].Equal(ctx, other.Args[i]) {
    53  			return false
    54  		}
    55  	}
    56  	return true
    57  }
    58  
    59  func (a *baseFuncDesc) clone() *baseFuncDesc {
    60  	clone := *a
    61  	newTp := *a.RetTp
    62  	clone.RetTp = &newTp
    63  	clone.Args = make([]memex.Expression, len(a.Args))
    64  	for i := range a.Args {
    65  		clone.Args[i] = a.Args[i].Clone()
    66  	}
    67  	return &clone
    68  }
    69  
    70  // String implements the fmt.Stringer interface.
    71  func (a *baseFuncDesc) String() string {
    72  	buffer := bytes.NewBufferString(a.Name)
    73  	buffer.WriteString("(")
    74  	for i, arg := range a.Args {
    75  		buffer.WriteString(arg.String())
    76  		if i+1 != len(a.Args) {
    77  			buffer.WriteString(", ")
    78  		}
    79  	}
    80  	buffer.WriteString(")")
    81  	return buffer.String()
    82  }
    83  
    84  // typeInfer infers the arguments and return types of an function.
    85  func (a *baseFuncDesc) typeInfer(ctx stochastikctx.Context) error {
    86  	switch a.Name {
    87  	case ast.AggFuncCount:
    88  		a.typeInfer4Count(ctx)
    89  	case ast.AggFuncApproxCountDistinct:
    90  		a.typeInfer4ApproxCountDistinct(ctx)
    91  	case ast.AggFuncSum:
    92  		a.typeInfer4Sum(ctx)
    93  	case ast.AggFuncAvg:
    94  		a.typeInfer4Avg(ctx)
    95  	case ast.AggFuncGroupConcat:
    96  		a.typeInfer4GroupConcat(ctx)
    97  	case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstEvent,
    98  		ast.WindowFuncFirstValue, ast.WindowFuncLastValue, ast.WindowFuncNthValue:
    99  		a.typeInfer4MaxMin(ctx)
   100  	case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
   101  		a.typeInfer4BitFuncs(ctx)
   102  	case ast.WindowFuncEventNumber, ast.WindowFuncRank, ast.WindowFuncDenseRank:
   103  		a.typeInfer4NumberFuncs()
   104  	case ast.WindowFuncCumeDist:
   105  		a.typeInfer4CumeDist()
   106  	case ast.WindowFuncNtile:
   107  		a.typeInfer4Ntile()
   108  	case ast.WindowFuncPercentRank:
   109  		a.typeInfer4PercentRank()
   110  	case ast.WindowFuncLead, ast.WindowFuncLag:
   111  		a.typeInfer4LeadLag(ctx)
   112  	case ast.AggFuncVarPop, ast.AggFuncStddevPop, ast.AggFuncVarSamp, ast.AggFuncStddevSamp:
   113  		a.typeInfer4PopOrSamp(ctx)
   114  	case ast.AggFuncJsonObjectAgg:
   115  		a.typeInfer4JsonFuncs(ctx)
   116  	default:
   117  		return errors.Errorf("unsupported agg function: %s", a.Name)
   118  	}
   119  	return nil
   120  }
   121  
   122  func (a *baseFuncDesc) typeInfer4Count(ctx stochastikctx.Context) {
   123  	a.RetTp = types.NewFieldType(allegrosql.TypeLonglong)
   124  	a.RetTp.Flen = 21
   125  	a.RetTp.Decimal = 0
   126  	// count never returns null
   127  	a.RetTp.Flag |= allegrosql.NotNullFlag
   128  	types.SetBinChsClnFlag(a.RetTp)
   129  }
   130  
   131  func (a *baseFuncDesc) typeInfer4ApproxCountDistinct(ctx stochastikctx.Context) {
   132  	a.typeInfer4Count(ctx)
   133  }
   134  
   135  // typeInfer4Sum should returns a "decimal", otherwise it returns a "double".
   136  // Because child returns integer or decimal type.
   137  func (a *baseFuncDesc) typeInfer4Sum(ctx stochastikctx.Context) {
   138  	switch a.Args[0].GetType().Tp {
   139  	case allegrosql.TypeTiny, allegrosql.TypeShort, allegrosql.TypeInt24, allegrosql.TypeLong, allegrosql.TypeLonglong, allegrosql.TypeYear:
   140  		a.RetTp = types.NewFieldType(allegrosql.TypeNewDecimal)
   141  		a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxDecimalWidth, 0
   142  	case allegrosql.TypeNewDecimal:
   143  		a.RetTp = types.NewFieldType(allegrosql.TypeNewDecimal)
   144  		a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxDecimalWidth, a.Args[0].GetType().Decimal
   145  		if a.RetTp.Decimal < 0 || a.RetTp.Decimal > allegrosql.MaxDecimalScale {
   146  			a.RetTp.Decimal = allegrosql.MaxDecimalScale
   147  		}
   148  	case allegrosql.TypeDouble, allegrosql.TypeFloat:
   149  		a.RetTp = types.NewFieldType(allegrosql.TypeDouble)
   150  		a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxRealWidth, a.Args[0].GetType().Decimal
   151  	default:
   152  		a.RetTp = types.NewFieldType(allegrosql.TypeDouble)
   153  		a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxRealWidth, types.UnspecifiedLength
   154  	}
   155  	types.SetBinChsClnFlag(a.RetTp)
   156  }
   157  
   158  // typeInfer4Avg should returns a "decimal", otherwise it returns a "double".
   159  // Because child returns integer or decimal type.
   160  func (a *baseFuncDesc) typeInfer4Avg(ctx stochastikctx.Context) {
   161  	switch a.Args[0].GetType().Tp {
   162  	case allegrosql.TypeTiny, allegrosql.TypeShort, allegrosql.TypeInt24, allegrosql.TypeLong, allegrosql.TypeLonglong, allegrosql.TypeYear, allegrosql.TypeNewDecimal:
   163  		a.RetTp = types.NewFieldType(allegrosql.TypeNewDecimal)
   164  		if a.Args[0].GetType().Decimal < 0 {
   165  			a.RetTp.Decimal = allegrosql.MaxDecimalScale
   166  		} else {
   167  			a.RetTp.Decimal = mathutil.Min(a.Args[0].GetType().Decimal+types.DivFracIncr, allegrosql.MaxDecimalScale)
   168  		}
   169  		a.RetTp.Flen = allegrosql.MaxDecimalWidth
   170  	case allegrosql.TypeDouble, allegrosql.TypeFloat:
   171  		a.RetTp = types.NewFieldType(allegrosql.TypeDouble)
   172  		a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxRealWidth, a.Args[0].GetType().Decimal
   173  	default:
   174  		a.RetTp = types.NewFieldType(allegrosql.TypeDouble)
   175  		a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxRealWidth, types.UnspecifiedLength
   176  	}
   177  	types.SetBinChsClnFlag(a.RetTp)
   178  }
   179  
   180  func (a *baseFuncDesc) typeInfer4GroupConcat(ctx stochastikctx.Context) {
   181  	a.RetTp = types.NewFieldType(allegrosql.TypeVarString)
   182  	a.RetTp.Charset, a.RetTp.DefCauslate = charset.GetDefaultCharsetAndDefCauslate()
   183  
   184  	a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxBlobWidth, 0
   185  	// TODO: a.Args[i] = memex.WrapWithCastAsString(ctx, a.Args[i])
   186  }
   187  
   188  func (a *baseFuncDesc) typeInfer4MaxMin(ctx stochastikctx.Context) {
   189  	_, argIsScalaFunc := a.Args[0].(*memex.ScalarFunction)
   190  	if argIsScalaFunc && a.Args[0].GetType().Tp == allegrosql.TypeFloat {
   191  		// For scalar function, the result of "float32" is set to the "float64"
   192  		// field in the "Causet". If we do not wrap a cast-as-double function on a.Args[0],
   193  		// error would happen when extracting the evaluation of a.Args[0] to a ProjectionInterDirc.
   194  		tp := types.NewFieldType(allegrosql.TypeDouble)
   195  		tp.Flen, tp.Decimal = allegrosql.MaxRealWidth, types.UnspecifiedLength
   196  		types.SetBinChsClnFlag(tp)
   197  		a.Args[0] = memex.BuildCastFunction(ctx, a.Args[0], tp)
   198  	}
   199  	a.RetTp = a.Args[0].GetType()
   200  	if (a.Name == ast.AggFuncMax || a.Name == ast.AggFuncMin) && a.RetTp.Tp != allegrosql.TypeBit {
   201  		a.RetTp = a.Args[0].GetType().Clone()
   202  		a.RetTp.Flag &^= allegrosql.NotNullFlag
   203  	}
   204  	// issue #13027, #13961
   205  	if (a.RetTp.Tp == allegrosql.TypeEnum || a.RetTp.Tp == allegrosql.TypeSet) &&
   206  		(a.Name != ast.AggFuncFirstEvent && a.Name != ast.AggFuncMax && a.Name != ast.AggFuncMin) {
   207  		a.RetTp = &types.FieldType{Tp: allegrosql.TypeString, Flen: allegrosql.MaxFieldCharLength}
   208  	}
   209  }
   210  
   211  func (a *baseFuncDesc) typeInfer4BitFuncs(ctx stochastikctx.Context) {
   212  	a.RetTp = types.NewFieldType(allegrosql.TypeLonglong)
   213  	a.RetTp.Flen = 21
   214  	types.SetBinChsClnFlag(a.RetTp)
   215  	a.RetTp.Flag |= allegrosql.UnsignedFlag | allegrosql.NotNullFlag
   216  	// TODO: a.Args[0] = memex.WrapWithCastAsInt(ctx, a.Args[0])
   217  }
   218  
   219  func (a *baseFuncDesc) typeInfer4JsonFuncs(ctx stochastikctx.Context) {
   220  	a.RetTp = types.NewFieldType(allegrosql.TypeJSON)
   221  	types.SetBinChsClnFlag(a.RetTp)
   222  }
   223  
   224  func (a *baseFuncDesc) typeInfer4NumberFuncs() {
   225  	a.RetTp = types.NewFieldType(allegrosql.TypeLonglong)
   226  	a.RetTp.Flen = 21
   227  	types.SetBinChsClnFlag(a.RetTp)
   228  }
   229  
   230  func (a *baseFuncDesc) typeInfer4CumeDist() {
   231  	a.RetTp = types.NewFieldType(allegrosql.TypeDouble)
   232  	a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxRealWidth, allegrosql.NotFixedDec
   233  }
   234  
   235  func (a *baseFuncDesc) typeInfer4Ntile() {
   236  	a.RetTp = types.NewFieldType(allegrosql.TypeLonglong)
   237  	a.RetTp.Flen = 21
   238  	types.SetBinChsClnFlag(a.RetTp)
   239  	a.RetTp.Flag |= allegrosql.UnsignedFlag
   240  }
   241  
   242  func (a *baseFuncDesc) typeInfer4PercentRank() {
   243  	a.RetTp = types.NewFieldType(allegrosql.TypeDouble)
   244  	a.RetTp.Flag, a.RetTp.Decimal = allegrosql.MaxRealWidth, allegrosql.NotFixedDec
   245  }
   246  
   247  func (a *baseFuncDesc) typeInfer4LeadLag(ctx stochastikctx.Context) {
   248  	if len(a.Args) <= 2 {
   249  		a.typeInfer4MaxMin(ctx)
   250  	} else {
   251  		// Merge the type of first and third argument.
   252  		a.RetTp = memex.InferType4ControlFuncs(a.Args[0], a.Args[2])
   253  	}
   254  }
   255  
   256  func (a *baseFuncDesc) typeInfer4PopOrSamp(ctx stochastikctx.Context) {
   257  	//var_pop/std/var_samp/stddev_samp's return value type is double
   258  	a.RetTp = types.NewFieldType(allegrosql.TypeDouble)
   259  	a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxRealWidth, types.UnspecifiedLength
   260  }
   261  
   262  // GetDefaultValue gets the default value when the function's input is null.
   263  // According to MyALLEGROSQL, default values of the function are listed as follows:
   264  // e.g.
   265  // Block t which is empty:
   266  // +-------+---------+---------+
   267  // | Block | Field   | Type    |
   268  // +-------+---------+---------+
   269  // | t     | a       | int(11) |
   270  // +-------+---------+---------+
   271  //
   272  // Query: `select avg(a), sum(a), count(a), bit_xor(a), bit_or(a), bit_and(a), max(a), min(a), group_concat(a), approx_count_distinct(a) from test.t;`
   273  //+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+--------------------------+
   274  //| avg(a) | sum(a) | count(a) | bit_xor(a) | bit_or(a) | bit_and(a)           | max(a) | min(a) | group_concat(a) | approx_count_distinct(a) |
   275  //+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+--------------------------+
   276  //|   NULL |   NULL |        0 |          0 |         0 | 18446744073709551615 |   NULL |   NULL | NULL            |                        0 |
   277  //+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+--------------------------+
   278  
   279  func (a *baseFuncDesc) GetDefaultValue() (v types.Causet) {
   280  	switch a.Name {
   281  	case ast.AggFuncCount, ast.AggFuncBitOr, ast.AggFuncBitXor:
   282  		v = types.NewIntCauset(0)
   283  	case ast.AggFuncApproxCountDistinct:
   284  		if a.RetTp.Tp != allegrosql.TypeString {
   285  			v = types.NewIntCauset(0)
   286  		}
   287  	case ast.AggFuncFirstEvent, ast.AggFuncAvg, ast.AggFuncSum, ast.AggFuncMax,
   288  		ast.AggFuncMin, ast.AggFuncGroupConcat:
   289  		v = types.Causet{}
   290  	case ast.AggFuncBitAnd:
   291  		v = types.NewUintCauset(uint64(math.MaxUint64))
   292  	}
   293  	return
   294  }
   295  
   296  // We do not need to wrap cast upon these functions,
   297  // since the EvalXXX method called by the arg is determined by the corresponding arg type.
   298  var noNeedCastAggFuncs = map[string]struct{}{
   299  	ast.AggFuncCount:               {},
   300  	ast.AggFuncApproxCountDistinct: {},
   301  	ast.AggFuncMax:                 {},
   302  	ast.AggFuncMin:                 {},
   303  	ast.AggFuncFirstEvent:            {},
   304  	ast.WindowFuncNtile:            {},
   305  	ast.AggFuncJsonObjectAgg:       {},
   306  }
   307  
   308  // WrapCastForAggArgs wraps the args of an aggregate function with a cast function.
   309  func (a *baseFuncDesc) WrapCastForAggArgs(ctx stochastikctx.Context) {
   310  	if len(a.Args) == 0 {
   311  		return
   312  	}
   313  	if _, ok := noNeedCastAggFuncs[a.Name]; ok {
   314  		return
   315  	}
   316  	var castFunc func(ctx stochastikctx.Context, expr memex.Expression) memex.Expression
   317  	switch retTp := a.RetTp; retTp.EvalType() {
   318  	case types.ETInt:
   319  		castFunc = memex.WrapWithCastAsInt
   320  	case types.ETReal:
   321  		castFunc = memex.WrapWithCastAsReal
   322  	case types.ETString:
   323  		castFunc = memex.WrapWithCastAsString
   324  	case types.ETDecimal:
   325  		castFunc = memex.WrapWithCastAsDecimal
   326  	case types.ETDatetime, types.ETTimestamp:
   327  		castFunc = func(ctx stochastikctx.Context, expr memex.Expression) memex.Expression {
   328  			return memex.WrapWithCastAsTime(ctx, expr, retTp)
   329  		}
   330  	case types.ETDuration:
   331  		castFunc = memex.WrapWithCastAsDuration
   332  	case types.ETJson:
   333  		castFunc = memex.WrapWithCastAsJSON
   334  	default:
   335  		panic("should never happen in baseFuncDesc.WrapCastForAggArgs")
   336  	}
   337  	for i := range a.Args {
   338  		// Do not cast the second args of these functions, as they are simply non-negative numbers.
   339  		if i == 1 && (a.Name == ast.WindowFuncLead || a.Name == ast.WindowFuncLag || a.Name == ast.WindowFuncNthValue) {
   340  			continue
   341  		}
   342  		a.Args[i] = castFunc(ctx, a.Args[i])
   343  		if a.Name != ast.AggFuncAvg && a.Name != ast.AggFuncSum {
   344  			continue
   345  		}
   346  		// After wrapping cast on the argument, flen etc. may not the same
   347  		// as the type of the aggregation function. The following part set
   348  		// the type of the argument exactly as the type of the aggregation
   349  		// function.
   350  		// Note: If the `Tp` of argument is the same as the `Tp` of the
   351  		// aggregation function, it will not wrap cast function on it
   352  		// internally. The reason of the special handling for `DeferredCauset` is
   353  		// that the `RetType` of `DeferredCauset` refers to the `schemareplicant`, so we
   354  		// need to set a new variable for it to avoid modifying the
   355  		// definition in `schemareplicant`.
   356  		if defCaus, ok := a.Args[i].(*memex.DeferredCauset); ok {
   357  			defCaus.RetType = types.NewFieldType(defCaus.RetType.Tp)
   358  		}
   359  		// originTp is used when the the `Tp` of defCausumn is TypeFloat32 while
   360  		// the type of the aggregation function is TypeFloat64.
   361  		originTp := a.Args[i].GetType().Tp
   362  		*(a.Args[i].GetType()) = *(a.RetTp)
   363  		a.Args[i].GetType().Tp = originTp
   364  	}
   365  }