github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/dbs/memristed/memex/aggregation/base_func.go (about) 1 // Copyright 2020 WHTCORPS INC, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package aggregation 15 16 import ( 17 "bytes" 18 "math" 19 "strings" 20 21 "github.com/cznic/mathutil" 22 "github.com/whtcorpsinc/errors" 23 "github.com/whtcorpsinc/BerolinaSQL/ast" 24 "github.com/whtcorpsinc/BerolinaSQL/charset" 25 "github.com/whtcorpsinc/BerolinaSQL/allegrosql" 26 "github.com/whtcorpsinc/milevadb/memex" 27 "github.com/whtcorpsinc/milevadb/stochastikctx" 28 "github.com/whtcorpsinc/milevadb/types" 29 ) 30 31 // baseFuncDesc describes an function signature, only used in causet. 32 type baseFuncDesc struct { 33 // Name represents the function name. 34 Name string 35 // Args represents the arguments of the function. 36 Args []memex.Expression 37 // RetTp represents the return type of the function. 38 RetTp *types.FieldType 39 } 40 41 func newBaseFuncDesc(ctx stochastikctx.Context, name string, args []memex.Expression) (baseFuncDesc, error) { 42 b := baseFuncDesc{Name: strings.ToLower(name), Args: args} 43 err := b.typeInfer(ctx) 44 return b, err 45 } 46 47 func (a *baseFuncDesc) equal(ctx stochastikctx.Context, other *baseFuncDesc) bool { 48 if a.Name != other.Name || len(a.Args) != len(other.Args) { 49 return false 50 } 51 for i := range a.Args { 52 if !a.Args[i].Equal(ctx, other.Args[i]) { 53 return false 54 } 55 } 56 return true 57 } 58 59 func (a *baseFuncDesc) clone() *baseFuncDesc { 60 clone := *a 61 newTp := *a.RetTp 62 clone.RetTp = &newTp 63 clone.Args = make([]memex.Expression, len(a.Args)) 64 for i := range a.Args { 65 clone.Args[i] = a.Args[i].Clone() 66 } 67 return &clone 68 } 69 70 // String implements the fmt.Stringer interface. 71 func (a *baseFuncDesc) String() string { 72 buffer := bytes.NewBufferString(a.Name) 73 buffer.WriteString("(") 74 for i, arg := range a.Args { 75 buffer.WriteString(arg.String()) 76 if i+1 != len(a.Args) { 77 buffer.WriteString(", ") 78 } 79 } 80 buffer.WriteString(")") 81 return buffer.String() 82 } 83 84 // typeInfer infers the arguments and return types of an function. 85 func (a *baseFuncDesc) typeInfer(ctx stochastikctx.Context) error { 86 switch a.Name { 87 case ast.AggFuncCount: 88 a.typeInfer4Count(ctx) 89 case ast.AggFuncApproxCountDistinct: 90 a.typeInfer4ApproxCountDistinct(ctx) 91 case ast.AggFuncSum: 92 a.typeInfer4Sum(ctx) 93 case ast.AggFuncAvg: 94 a.typeInfer4Avg(ctx) 95 case ast.AggFuncGroupConcat: 96 a.typeInfer4GroupConcat(ctx) 97 case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstEvent, 98 ast.WindowFuncFirstValue, ast.WindowFuncLastValue, ast.WindowFuncNthValue: 99 a.typeInfer4MaxMin(ctx) 100 case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor: 101 a.typeInfer4BitFuncs(ctx) 102 case ast.WindowFuncEventNumber, ast.WindowFuncRank, ast.WindowFuncDenseRank: 103 a.typeInfer4NumberFuncs() 104 case ast.WindowFuncCumeDist: 105 a.typeInfer4CumeDist() 106 case ast.WindowFuncNtile: 107 a.typeInfer4Ntile() 108 case ast.WindowFuncPercentRank: 109 a.typeInfer4PercentRank() 110 case ast.WindowFuncLead, ast.WindowFuncLag: 111 a.typeInfer4LeadLag(ctx) 112 case ast.AggFuncVarPop, ast.AggFuncStddevPop, ast.AggFuncVarSamp, ast.AggFuncStddevSamp: 113 a.typeInfer4PopOrSamp(ctx) 114 case ast.AggFuncJsonObjectAgg: 115 a.typeInfer4JsonFuncs(ctx) 116 default: 117 return errors.Errorf("unsupported agg function: %s", a.Name) 118 } 119 return nil 120 } 121 122 func (a *baseFuncDesc) typeInfer4Count(ctx stochastikctx.Context) { 123 a.RetTp = types.NewFieldType(allegrosql.TypeLonglong) 124 a.RetTp.Flen = 21 125 a.RetTp.Decimal = 0 126 // count never returns null 127 a.RetTp.Flag |= allegrosql.NotNullFlag 128 types.SetBinChsClnFlag(a.RetTp) 129 } 130 131 func (a *baseFuncDesc) typeInfer4ApproxCountDistinct(ctx stochastikctx.Context) { 132 a.typeInfer4Count(ctx) 133 } 134 135 // typeInfer4Sum should returns a "decimal", otherwise it returns a "double". 136 // Because child returns integer or decimal type. 137 func (a *baseFuncDesc) typeInfer4Sum(ctx stochastikctx.Context) { 138 switch a.Args[0].GetType().Tp { 139 case allegrosql.TypeTiny, allegrosql.TypeShort, allegrosql.TypeInt24, allegrosql.TypeLong, allegrosql.TypeLonglong, allegrosql.TypeYear: 140 a.RetTp = types.NewFieldType(allegrosql.TypeNewDecimal) 141 a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxDecimalWidth, 0 142 case allegrosql.TypeNewDecimal: 143 a.RetTp = types.NewFieldType(allegrosql.TypeNewDecimal) 144 a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxDecimalWidth, a.Args[0].GetType().Decimal 145 if a.RetTp.Decimal < 0 || a.RetTp.Decimal > allegrosql.MaxDecimalScale { 146 a.RetTp.Decimal = allegrosql.MaxDecimalScale 147 } 148 case allegrosql.TypeDouble, allegrosql.TypeFloat: 149 a.RetTp = types.NewFieldType(allegrosql.TypeDouble) 150 a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxRealWidth, a.Args[0].GetType().Decimal 151 default: 152 a.RetTp = types.NewFieldType(allegrosql.TypeDouble) 153 a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxRealWidth, types.UnspecifiedLength 154 } 155 types.SetBinChsClnFlag(a.RetTp) 156 } 157 158 // typeInfer4Avg should returns a "decimal", otherwise it returns a "double". 159 // Because child returns integer or decimal type. 160 func (a *baseFuncDesc) typeInfer4Avg(ctx stochastikctx.Context) { 161 switch a.Args[0].GetType().Tp { 162 case allegrosql.TypeTiny, allegrosql.TypeShort, allegrosql.TypeInt24, allegrosql.TypeLong, allegrosql.TypeLonglong, allegrosql.TypeYear, allegrosql.TypeNewDecimal: 163 a.RetTp = types.NewFieldType(allegrosql.TypeNewDecimal) 164 if a.Args[0].GetType().Decimal < 0 { 165 a.RetTp.Decimal = allegrosql.MaxDecimalScale 166 } else { 167 a.RetTp.Decimal = mathutil.Min(a.Args[0].GetType().Decimal+types.DivFracIncr, allegrosql.MaxDecimalScale) 168 } 169 a.RetTp.Flen = allegrosql.MaxDecimalWidth 170 case allegrosql.TypeDouble, allegrosql.TypeFloat: 171 a.RetTp = types.NewFieldType(allegrosql.TypeDouble) 172 a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxRealWidth, a.Args[0].GetType().Decimal 173 default: 174 a.RetTp = types.NewFieldType(allegrosql.TypeDouble) 175 a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxRealWidth, types.UnspecifiedLength 176 } 177 types.SetBinChsClnFlag(a.RetTp) 178 } 179 180 func (a *baseFuncDesc) typeInfer4GroupConcat(ctx stochastikctx.Context) { 181 a.RetTp = types.NewFieldType(allegrosql.TypeVarString) 182 a.RetTp.Charset, a.RetTp.DefCauslate = charset.GetDefaultCharsetAndDefCauslate() 183 184 a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxBlobWidth, 0 185 // TODO: a.Args[i] = memex.WrapWithCastAsString(ctx, a.Args[i]) 186 } 187 188 func (a *baseFuncDesc) typeInfer4MaxMin(ctx stochastikctx.Context) { 189 _, argIsScalaFunc := a.Args[0].(*memex.ScalarFunction) 190 if argIsScalaFunc && a.Args[0].GetType().Tp == allegrosql.TypeFloat { 191 // For scalar function, the result of "float32" is set to the "float64" 192 // field in the "Causet". If we do not wrap a cast-as-double function on a.Args[0], 193 // error would happen when extracting the evaluation of a.Args[0] to a ProjectionInterDirc. 194 tp := types.NewFieldType(allegrosql.TypeDouble) 195 tp.Flen, tp.Decimal = allegrosql.MaxRealWidth, types.UnspecifiedLength 196 types.SetBinChsClnFlag(tp) 197 a.Args[0] = memex.BuildCastFunction(ctx, a.Args[0], tp) 198 } 199 a.RetTp = a.Args[0].GetType() 200 if (a.Name == ast.AggFuncMax || a.Name == ast.AggFuncMin) && a.RetTp.Tp != allegrosql.TypeBit { 201 a.RetTp = a.Args[0].GetType().Clone() 202 a.RetTp.Flag &^= allegrosql.NotNullFlag 203 } 204 // issue #13027, #13961 205 if (a.RetTp.Tp == allegrosql.TypeEnum || a.RetTp.Tp == allegrosql.TypeSet) && 206 (a.Name != ast.AggFuncFirstEvent && a.Name != ast.AggFuncMax && a.Name != ast.AggFuncMin) { 207 a.RetTp = &types.FieldType{Tp: allegrosql.TypeString, Flen: allegrosql.MaxFieldCharLength} 208 } 209 } 210 211 func (a *baseFuncDesc) typeInfer4BitFuncs(ctx stochastikctx.Context) { 212 a.RetTp = types.NewFieldType(allegrosql.TypeLonglong) 213 a.RetTp.Flen = 21 214 types.SetBinChsClnFlag(a.RetTp) 215 a.RetTp.Flag |= allegrosql.UnsignedFlag | allegrosql.NotNullFlag 216 // TODO: a.Args[0] = memex.WrapWithCastAsInt(ctx, a.Args[0]) 217 } 218 219 func (a *baseFuncDesc) typeInfer4JsonFuncs(ctx stochastikctx.Context) { 220 a.RetTp = types.NewFieldType(allegrosql.TypeJSON) 221 types.SetBinChsClnFlag(a.RetTp) 222 } 223 224 func (a *baseFuncDesc) typeInfer4NumberFuncs() { 225 a.RetTp = types.NewFieldType(allegrosql.TypeLonglong) 226 a.RetTp.Flen = 21 227 types.SetBinChsClnFlag(a.RetTp) 228 } 229 230 func (a *baseFuncDesc) typeInfer4CumeDist() { 231 a.RetTp = types.NewFieldType(allegrosql.TypeDouble) 232 a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxRealWidth, allegrosql.NotFixedDec 233 } 234 235 func (a *baseFuncDesc) typeInfer4Ntile() { 236 a.RetTp = types.NewFieldType(allegrosql.TypeLonglong) 237 a.RetTp.Flen = 21 238 types.SetBinChsClnFlag(a.RetTp) 239 a.RetTp.Flag |= allegrosql.UnsignedFlag 240 } 241 242 func (a *baseFuncDesc) typeInfer4PercentRank() { 243 a.RetTp = types.NewFieldType(allegrosql.TypeDouble) 244 a.RetTp.Flag, a.RetTp.Decimal = allegrosql.MaxRealWidth, allegrosql.NotFixedDec 245 } 246 247 func (a *baseFuncDesc) typeInfer4LeadLag(ctx stochastikctx.Context) { 248 if len(a.Args) <= 2 { 249 a.typeInfer4MaxMin(ctx) 250 } else { 251 // Merge the type of first and third argument. 252 a.RetTp = memex.InferType4ControlFuncs(a.Args[0], a.Args[2]) 253 } 254 } 255 256 func (a *baseFuncDesc) typeInfer4PopOrSamp(ctx stochastikctx.Context) { 257 //var_pop/std/var_samp/stddev_samp's return value type is double 258 a.RetTp = types.NewFieldType(allegrosql.TypeDouble) 259 a.RetTp.Flen, a.RetTp.Decimal = allegrosql.MaxRealWidth, types.UnspecifiedLength 260 } 261 262 // GetDefaultValue gets the default value when the function's input is null. 263 // According to MyALLEGROSQL, default values of the function are listed as follows: 264 // e.g. 265 // Block t which is empty: 266 // +-------+---------+---------+ 267 // | Block | Field | Type | 268 // +-------+---------+---------+ 269 // | t | a | int(11) | 270 // +-------+---------+---------+ 271 // 272 // Query: `select avg(a), sum(a), count(a), bit_xor(a), bit_or(a), bit_and(a), max(a), min(a), group_concat(a), approx_count_distinct(a) from test.t;` 273 //+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+--------------------------+ 274 //| avg(a) | sum(a) | count(a) | bit_xor(a) | bit_or(a) | bit_and(a) | max(a) | min(a) | group_concat(a) | approx_count_distinct(a) | 275 //+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+--------------------------+ 276 //| NULL | NULL | 0 | 0 | 0 | 18446744073709551615 | NULL | NULL | NULL | 0 | 277 //+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+--------------------------+ 278 279 func (a *baseFuncDesc) GetDefaultValue() (v types.Causet) { 280 switch a.Name { 281 case ast.AggFuncCount, ast.AggFuncBitOr, ast.AggFuncBitXor: 282 v = types.NewIntCauset(0) 283 case ast.AggFuncApproxCountDistinct: 284 if a.RetTp.Tp != allegrosql.TypeString { 285 v = types.NewIntCauset(0) 286 } 287 case ast.AggFuncFirstEvent, ast.AggFuncAvg, ast.AggFuncSum, ast.AggFuncMax, 288 ast.AggFuncMin, ast.AggFuncGroupConcat: 289 v = types.Causet{} 290 case ast.AggFuncBitAnd: 291 v = types.NewUintCauset(uint64(math.MaxUint64)) 292 } 293 return 294 } 295 296 // We do not need to wrap cast upon these functions, 297 // since the EvalXXX method called by the arg is determined by the corresponding arg type. 298 var noNeedCastAggFuncs = map[string]struct{}{ 299 ast.AggFuncCount: {}, 300 ast.AggFuncApproxCountDistinct: {}, 301 ast.AggFuncMax: {}, 302 ast.AggFuncMin: {}, 303 ast.AggFuncFirstEvent: {}, 304 ast.WindowFuncNtile: {}, 305 ast.AggFuncJsonObjectAgg: {}, 306 } 307 308 // WrapCastForAggArgs wraps the args of an aggregate function with a cast function. 309 func (a *baseFuncDesc) WrapCastForAggArgs(ctx stochastikctx.Context) { 310 if len(a.Args) == 0 { 311 return 312 } 313 if _, ok := noNeedCastAggFuncs[a.Name]; ok { 314 return 315 } 316 var castFunc func(ctx stochastikctx.Context, expr memex.Expression) memex.Expression 317 switch retTp := a.RetTp; retTp.EvalType() { 318 case types.ETInt: 319 castFunc = memex.WrapWithCastAsInt 320 case types.ETReal: 321 castFunc = memex.WrapWithCastAsReal 322 case types.ETString: 323 castFunc = memex.WrapWithCastAsString 324 case types.ETDecimal: 325 castFunc = memex.WrapWithCastAsDecimal 326 case types.ETDatetime, types.ETTimestamp: 327 castFunc = func(ctx stochastikctx.Context, expr memex.Expression) memex.Expression { 328 return memex.WrapWithCastAsTime(ctx, expr, retTp) 329 } 330 case types.ETDuration: 331 castFunc = memex.WrapWithCastAsDuration 332 case types.ETJson: 333 castFunc = memex.WrapWithCastAsJSON 334 default: 335 panic("should never happen in baseFuncDesc.WrapCastForAggArgs") 336 } 337 for i := range a.Args { 338 // Do not cast the second args of these functions, as they are simply non-negative numbers. 339 if i == 1 && (a.Name == ast.WindowFuncLead || a.Name == ast.WindowFuncLag || a.Name == ast.WindowFuncNthValue) { 340 continue 341 } 342 a.Args[i] = castFunc(ctx, a.Args[i]) 343 if a.Name != ast.AggFuncAvg && a.Name != ast.AggFuncSum { 344 continue 345 } 346 // After wrapping cast on the argument, flen etc. may not the same 347 // as the type of the aggregation function. The following part set 348 // the type of the argument exactly as the type of the aggregation 349 // function. 350 // Note: If the `Tp` of argument is the same as the `Tp` of the 351 // aggregation function, it will not wrap cast function on it 352 // internally. The reason of the special handling for `DeferredCauset` is 353 // that the `RetType` of `DeferredCauset` refers to the `schemareplicant`, so we 354 // need to set a new variable for it to avoid modifying the 355 // definition in `schemareplicant`. 356 if defCaus, ok := a.Args[i].(*memex.DeferredCauset); ok { 357 defCaus.RetType = types.NewFieldType(defCaus.RetType.Tp) 358 } 359 // originTp is used when the the `Tp` of defCausumn is TypeFloat32 while 360 // the type of the aggregation function is TypeFloat64. 361 originTp := a.Args[i].GetType().Tp 362 *(a.Args[i].GetType()) = *(a.RetTp) 363 a.Args[i].GetType().Tp = originTp 364 } 365 }