github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/dbs/memristed/memex/scalar_function.go (about) 1 // Copyright 2020 WHTCORPS INC, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package memex 15 16 import ( 17 "bytes" 18 "fmt" 19 20 "github.com/whtcorpsinc/errors" 21 "github.com/whtcorpsinc/BerolinaSQL/ast" 22 "github.com/whtcorpsinc/BerolinaSQL/perceptron" 23 "github.com/whtcorpsinc/BerolinaSQL/allegrosql" 24 "github.com/whtcorpsinc/BerolinaSQL/terror" 25 "github.com/whtcorpsinc/milevadb/stochastikctx" 26 "github.com/whtcorpsinc/milevadb/stochastikctx/stmtctx" 27 "github.com/whtcorpsinc/milevadb/types" 28 "github.com/whtcorpsinc/milevadb/types/json" 29 "github.com/whtcorpsinc/milevadb/soliton/chunk" 30 "github.com/whtcorpsinc/milevadb/soliton/codec" 31 "github.com/whtcorpsinc/milevadb/soliton/replog" 32 ) 33 34 // error definitions. 35 var ( 36 ErrNoDB = terror.ClassOptimizer.New(allegrosql.ErrNoDB, allegrosql.MyALLEGROSQLErrName[allegrosql.ErrNoDB]) 37 ) 38 39 // ScalarFunction is the function that returns a value. 40 type ScalarFunction struct { 41 FuncName perceptron.CIStr 42 // RetType is the type that ScalarFunction returns. 43 // TODO: Implement type inference here, now we use ast's return type temporarily. 44 RetType *types.FieldType 45 Function builtinFunc 46 hashcode []byte 47 } 48 49 // VecEvalInt evaluates this memex in a vectorized manner. 50 func (sf *ScalarFunction) VecEvalInt(ctx stochastikctx.Context, input *chunk.Chunk, result *chunk.DeferredCauset) error { 51 return sf.Function.vecEvalInt(input, result) 52 } 53 54 // VecEvalReal evaluates this memex in a vectorized manner. 55 func (sf *ScalarFunction) VecEvalReal(ctx stochastikctx.Context, input *chunk.Chunk, result *chunk.DeferredCauset) error { 56 return sf.Function.vecEvalReal(input, result) 57 } 58 59 // VecEvalString evaluates this memex in a vectorized manner. 60 func (sf *ScalarFunction) VecEvalString(ctx stochastikctx.Context, input *chunk.Chunk, result *chunk.DeferredCauset) error { 61 return sf.Function.vecEvalString(input, result) 62 } 63 64 // VecEvalDecimal evaluates this memex in a vectorized manner. 65 func (sf *ScalarFunction) VecEvalDecimal(ctx stochastikctx.Context, input *chunk.Chunk, result *chunk.DeferredCauset) error { 66 return sf.Function.vecEvalDecimal(input, result) 67 } 68 69 // VecEvalTime evaluates this memex in a vectorized manner. 70 func (sf *ScalarFunction) VecEvalTime(ctx stochastikctx.Context, input *chunk.Chunk, result *chunk.DeferredCauset) error { 71 return sf.Function.vecEvalTime(input, result) 72 } 73 74 // VecEvalDuration evaluates this memex in a vectorized manner. 75 func (sf *ScalarFunction) VecEvalDuration(ctx stochastikctx.Context, input *chunk.Chunk, result *chunk.DeferredCauset) error { 76 return sf.Function.vecEvalDuration(input, result) 77 } 78 79 // VecEvalJSON evaluates this memex in a vectorized manner. 80 func (sf *ScalarFunction) VecEvalJSON(ctx stochastikctx.Context, input *chunk.Chunk, result *chunk.DeferredCauset) error { 81 return sf.Function.vecEvalJSON(input, result) 82 } 83 84 // GetArgs gets arguments of function. 85 func (sf *ScalarFunction) GetArgs() []Expression { 86 return sf.Function.getArgs() 87 } 88 89 // Vectorized returns if this memex supports vectorized evaluation. 90 func (sf *ScalarFunction) Vectorized() bool { 91 return sf.Function.vectorized() && sf.Function.isChildrenVectorized() 92 } 93 94 // SupportReverseEval returns if this memex supports reversed evaluation. 95 func (sf *ScalarFunction) SupportReverseEval() bool { 96 switch sf.RetType.Tp { 97 case allegrosql.TypeShort, allegrosql.TypeLong, allegrosql.TypeLonglong, 98 allegrosql.TypeFloat, allegrosql.TypeDouble, allegrosql.TypeNewDecimal: 99 return sf.Function.supportReverseEval() && sf.Function.isChildrenReversed() 100 } 101 return false 102 } 103 104 // ReverseEval evaluates the only one defCausumn value with given function result. 105 func (sf *ScalarFunction) ReverseEval(sc *stmtctx.StatementContext, res types.Causet, rType types.RoundingType) (val types.Causet, err error) { 106 return sf.Function.reverseEval(sc, res, rType) 107 } 108 109 // GetCtx gets the context of function. 110 func (sf *ScalarFunction) GetCtx() stochastikctx.Context { 111 return sf.Function.getCtx() 112 } 113 114 // String implements fmt.Stringer interface. 115 func (sf *ScalarFunction) String() string { 116 var buffer bytes.Buffer 117 fmt.Fprintf(&buffer, "%s(", sf.FuncName.L) 118 switch sf.FuncName.L { 119 case ast.Cast: 120 for _, arg := range sf.GetArgs() { 121 buffer.WriteString(arg.String()) 122 buffer.WriteString(", ") 123 buffer.WriteString(sf.RetType.String()) 124 } 125 default: 126 for i, arg := range sf.GetArgs() { 127 buffer.WriteString(arg.String()) 128 if i+1 != len(sf.GetArgs()) { 129 buffer.WriteString(", ") 130 } 131 } 132 } 133 buffer.WriteString(")") 134 return buffer.String() 135 } 136 137 // MarshalJSON implements json.Marshaler interface. 138 func (sf *ScalarFunction) MarshalJSON() ([]byte, error) { 139 return []byte(fmt.Sprintf("%q", sf)), nil 140 } 141 142 // typeInferForNull infers the NULL constants field type and set the field type 143 // of NULL constant same as other non-null operands. 144 func typeInferForNull(args []Expression) { 145 if len(args) < 2 { 146 return 147 } 148 var isNull = func(expr Expression) bool { 149 cons, ok := expr.(*Constant) 150 return ok && cons.RetType.Tp == allegrosql.TypeNull && cons.Value.IsNull() 151 } 152 // Infer the actual field type of the NULL constant. 153 var retFieldTp *types.FieldType 154 var hasNullArg bool 155 for _, arg := range args { 156 isNullArg := isNull(arg) 157 if !isNullArg && retFieldTp == nil { 158 retFieldTp = arg.GetType() 159 } 160 hasNullArg = hasNullArg || isNullArg 161 // Break if there are both NULL and non-NULL memex 162 if hasNullArg && retFieldTp != nil { 163 break 164 } 165 } 166 if !hasNullArg || retFieldTp == nil { 167 return 168 } 169 for _, arg := range args { 170 if isNull(arg) { 171 *arg.GetType() = *retFieldTp 172 } 173 } 174 } 175 176 // newFunctionImpl creates a new scalar function or constant. 177 // fold: 1 means folding constants, while 0 means not, 178 // -1 means try to fold constants if without errors/warnings, otherwise not. 179 func newFunctionImpl(ctx stochastikctx.Context, fold int, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) { 180 if retType == nil { 181 return nil, errors.Errorf("RetType cannot be nil for ScalarFunction.") 182 } 183 if funcName == ast.Cast { 184 return BuildCastFunction(ctx, args[0], retType), nil 185 } 186 fc, ok := funcs[funcName] 187 if !ok { 188 EDB := ctx.GetStochastikVars().CurrentDB 189 if EDB == "" { 190 return nil, errors.Trace(ErrNoDB) 191 } 192 193 return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", EDB+"."+funcName) 194 } 195 if !ctx.GetStochastikVars().EnableNoopFuncs { 196 if _, ok := noopFuncs[funcName]; ok { 197 return nil, ErrFunctionsNoopImpl.GenWithStackByArgs(funcName) 198 } 199 } 200 funcArgs := make([]Expression, len(args)) 201 copy(funcArgs, args) 202 typeInferForNull(funcArgs) 203 f, err := fc.getFunction(ctx, funcArgs) 204 if err != nil { 205 return nil, err 206 } 207 if builtinRetTp := f.getRetTp(); builtinRetTp.Tp != allegrosql.TypeUnspecified || retType.Tp == allegrosql.TypeUnspecified { 208 retType = builtinRetTp 209 } 210 sf := &ScalarFunction{ 211 FuncName: perceptron.NewCIStr(funcName), 212 RetType: retType, 213 Function: f, 214 } 215 if fold == 1 { 216 return FoldConstant(sf), nil 217 } else if fold == -1 { 218 // try to fold constants, and return the original function if errors/warnings occur 219 sc := ctx.GetStochastikVars().StmtCtx 220 beforeWarns := sc.WarningCount() 221 newSf := FoldConstant(sf) 222 afterWarns := sc.WarningCount() 223 if afterWarns > beforeWarns { 224 sc.TruncateWarnings(int(beforeWarns)) 225 return sf, nil 226 } 227 return newSf, nil 228 } 229 return sf, nil 230 } 231 232 // NewFunction creates a new scalar function or constant via a constant folding. 233 func NewFunction(ctx stochastikctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) { 234 return newFunctionImpl(ctx, 1, funcName, retType, args...) 235 } 236 237 // NewFunctionBase creates a new scalar function with no constant folding. 238 func NewFunctionBase(ctx stochastikctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) { 239 return newFunctionImpl(ctx, 0, funcName, retType, args...) 240 } 241 242 // NewFunctionTryFold creates a new scalar function with trying constant folding. 243 func NewFunctionTryFold(ctx stochastikctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) { 244 return newFunctionImpl(ctx, -1, funcName, retType, args...) 245 } 246 247 // NewFunctionInternal is similar to NewFunction, but do not returns error, should only be used internally. 248 func NewFunctionInternal(ctx stochastikctx.Context, funcName string, retType *types.FieldType, args ...Expression) Expression { 249 expr, err := NewFunction(ctx, funcName, retType, args...) 250 terror.Log(err) 251 return expr 252 } 253 254 // ScalarFuncs2Exprs converts []*ScalarFunction to []Expression. 255 func ScalarFuncs2Exprs(funcs []*ScalarFunction) []Expression { 256 result := make([]Expression, 0, len(funcs)) 257 for _, defCaus := range funcs { 258 result = append(result, defCaus) 259 } 260 return result 261 } 262 263 // Clone implements Expression interface. 264 func (sf *ScalarFunction) Clone() Expression { 265 c := &ScalarFunction{ 266 FuncName: sf.FuncName, 267 RetType: sf.RetType, 268 Function: sf.Function.Clone(), 269 hashcode: sf.hashcode, 270 } 271 c.SetCharsetAndDefCauslation(sf.CharsetAndDefCauslation(sf.GetCtx())) 272 c.SetCoercibility(sf.Coercibility()) 273 return c 274 } 275 276 // GetType implements Expression interface. 277 func (sf *ScalarFunction) GetType() *types.FieldType { 278 return sf.RetType 279 } 280 281 // Equal implements Expression interface. 282 func (sf *ScalarFunction) Equal(ctx stochastikctx.Context, e Expression) bool { 283 fun, ok := e.(*ScalarFunction) 284 if !ok { 285 return false 286 } 287 if sf.FuncName.L != fun.FuncName.L { 288 return false 289 } 290 return sf.Function.equal(fun.Function) 291 } 292 293 // IsCorrelated implements Expression interface. 294 func (sf *ScalarFunction) IsCorrelated() bool { 295 for _, arg := range sf.GetArgs() { 296 if arg.IsCorrelated() { 297 return true 298 } 299 } 300 return false 301 } 302 303 // ConstItem implements Expression interface. 304 func (sf *ScalarFunction) ConstItem(sc *stmtctx.StatementContext) bool { 305 // Note: some unfoldable functions are deterministic, we use unFoldableFunctions here for simplification. 306 if _, ok := unFoldableFunctions[sf.FuncName.L]; ok { 307 return false 308 } 309 for _, arg := range sf.GetArgs() { 310 if !arg.ConstItem(sc) { 311 return false 312 } 313 } 314 return true 315 } 316 317 // Decorrelate implements Expression interface. 318 func (sf *ScalarFunction) Decorrelate(schemaReplicant *Schema) Expression { 319 for i, arg := range sf.GetArgs() { 320 sf.GetArgs()[i] = arg.Decorrelate(schemaReplicant) 321 } 322 return sf 323 } 324 325 // Eval implements Expression interface. 326 func (sf *ScalarFunction) Eval(event chunk.Event) (d types.Causet, err error) { 327 var ( 328 res interface{} 329 isNull bool 330 ) 331 switch tp, evalType := sf.GetType(), sf.GetType().EvalType(); evalType { 332 case types.ETInt: 333 var intRes int64 334 intRes, isNull, err = sf.EvalInt(sf.GetCtx(), event) 335 if allegrosql.HasUnsignedFlag(tp.Flag) { 336 res = uint64(intRes) 337 } else { 338 res = intRes 339 } 340 case types.ETReal: 341 res, isNull, err = sf.EvalReal(sf.GetCtx(), event) 342 case types.ETDecimal: 343 res, isNull, err = sf.EvalDecimal(sf.GetCtx(), event) 344 case types.ETDatetime, types.ETTimestamp: 345 res, isNull, err = sf.EvalTime(sf.GetCtx(), event) 346 case types.ETDuration: 347 res, isNull, err = sf.EvalDuration(sf.GetCtx(), event) 348 case types.ETJson: 349 res, isNull, err = sf.EvalJSON(sf.GetCtx(), event) 350 case types.ETString: 351 res, isNull, err = sf.EvalString(sf.GetCtx(), event) 352 } 353 354 if isNull || err != nil { 355 d.SetNull() 356 return d, err 357 } 358 d.SetValue(res, sf.RetType) 359 return 360 } 361 362 // EvalInt implements Expression interface. 363 func (sf *ScalarFunction) EvalInt(ctx stochastikctx.Context, event chunk.Event) (int64, bool, error) { 364 if f, ok := sf.Function.(builtinFuncNew); ok { 365 return f.evalIntWithCtx(ctx, event) 366 } 367 return sf.Function.evalInt(event) 368 } 369 370 // EvalReal implements Expression interface. 371 func (sf *ScalarFunction) EvalReal(ctx stochastikctx.Context, event chunk.Event) (float64, bool, error) { 372 return sf.Function.evalReal(event) 373 } 374 375 // EvalDecimal implements Expression interface. 376 func (sf *ScalarFunction) EvalDecimal(ctx stochastikctx.Context, event chunk.Event) (*types.MyDecimal, bool, error) { 377 return sf.Function.evalDecimal(event) 378 } 379 380 // EvalString implements Expression interface. 381 func (sf *ScalarFunction) EvalString(ctx stochastikctx.Context, event chunk.Event) (string, bool, error) { 382 return sf.Function.evalString(event) 383 } 384 385 // EvalTime implements Expression interface. 386 func (sf *ScalarFunction) EvalTime(ctx stochastikctx.Context, event chunk.Event) (types.Time, bool, error) { 387 return sf.Function.evalTime(event) 388 } 389 390 // EvalDuration implements Expression interface. 391 func (sf *ScalarFunction) EvalDuration(ctx stochastikctx.Context, event chunk.Event) (types.Duration, bool, error) { 392 return sf.Function.evalDuration(event) 393 } 394 395 // EvalJSON implements Expression interface. 396 func (sf *ScalarFunction) EvalJSON(ctx stochastikctx.Context, event chunk.Event) (json.BinaryJSON, bool, error) { 397 return sf.Function.evalJSON(event) 398 } 399 400 // HashCode implements Expression interface. 401 func (sf *ScalarFunction) HashCode(sc *stmtctx.StatementContext) []byte { 402 if len(sf.hashcode) > 0 { 403 return sf.hashcode 404 } 405 sf.hashcode = append(sf.hashcode, scalarFunctionFlag) 406 sf.hashcode = codec.EncodeCompactBytes(sf.hashcode, replog.Slice(sf.FuncName.L)) 407 for _, arg := range sf.GetArgs() { 408 sf.hashcode = append(sf.hashcode, arg.HashCode(sc)...) 409 } 410 return sf.hashcode 411 } 412 413 // ResolveIndices implements Expression interface. 414 func (sf *ScalarFunction) ResolveIndices(schemaReplicant *Schema) (Expression, error) { 415 newSf := sf.Clone() 416 err := newSf.resolveIndices(schemaReplicant) 417 return newSf, err 418 } 419 420 func (sf *ScalarFunction) resolveIndices(schemaReplicant *Schema) error { 421 if sf.FuncName.L == ast.In { 422 args := []Expression{} 423 switch inFunc := sf.Function.(type) { 424 case *builtinInIntSig: 425 args = inFunc.nonConstArgs 426 case *builtinInStringSig: 427 args = inFunc.nonConstArgs 428 case *builtinInTimeSig: 429 args = inFunc.nonConstArgs 430 case *builtinInDurationSig: 431 args = inFunc.nonConstArgs 432 case *builtinInRealSig: 433 args = inFunc.nonConstArgs 434 case *builtinInDecimalSig: 435 args = inFunc.nonConstArgs 436 } 437 for _, arg := range args { 438 err := arg.resolveIndices(schemaReplicant) 439 if err != nil { 440 return err 441 } 442 } 443 } 444 for _, arg := range sf.GetArgs() { 445 err := arg.resolveIndices(schemaReplicant) 446 if err != nil { 447 return err 448 } 449 } 450 return nil 451 } 452 453 // GetSingleDeferredCauset returns (DefCaus, Desc) when the ScalarFunction is equivalent to (DefCaus, Desc) 454 // when used as a sort key, otherwise returns (nil, false). 455 // 456 // Can only handle: 457 // - ast.Plus 458 // - ast.Minus 459 // - ast.UnaryMinus 460 func (sf *ScalarFunction) GetSingleDeferredCauset(reverse bool) (*DeferredCauset, bool) { 461 switch sf.FuncName.String() { 462 case ast.Plus: 463 args := sf.GetArgs() 464 switch tp := args[0].(type) { 465 case *DeferredCauset: 466 if _, ok := args[1].(*Constant); !ok { 467 return nil, false 468 } 469 return tp, reverse 470 case *ScalarFunction: 471 if _, ok := args[1].(*Constant); !ok { 472 return nil, false 473 } 474 return tp.GetSingleDeferredCauset(reverse) 475 case *Constant: 476 switch rtp := args[1].(type) { 477 case *DeferredCauset: 478 return rtp, reverse 479 case *ScalarFunction: 480 return rtp.GetSingleDeferredCauset(reverse) 481 } 482 } 483 return nil, false 484 case ast.Minus: 485 args := sf.GetArgs() 486 switch tp := args[0].(type) { 487 case *DeferredCauset: 488 if _, ok := args[1].(*Constant); !ok { 489 return nil, false 490 } 491 return tp, reverse 492 case *ScalarFunction: 493 if _, ok := args[1].(*Constant); !ok { 494 return nil, false 495 } 496 return tp.GetSingleDeferredCauset(reverse) 497 case *Constant: 498 switch rtp := args[1].(type) { 499 case *DeferredCauset: 500 return rtp, !reverse 501 case *ScalarFunction: 502 return rtp.GetSingleDeferredCauset(!reverse) 503 } 504 } 505 return nil, false 506 case ast.UnaryMinus: 507 args := sf.GetArgs() 508 switch tp := args[0].(type) { 509 case *DeferredCauset: 510 return tp, !reverse 511 case *ScalarFunction: 512 return tp.GetSingleDeferredCauset(!reverse) 513 } 514 return nil, false 515 } 516 return nil, false 517 } 518 519 // Coercibility returns the coercibility value which is used to check defCauslations. 520 func (sf *ScalarFunction) Coercibility() Coercibility { 521 if !sf.Function.HasCoercibility() { 522 sf.SetCoercibility(deriveCoercibilityForScarlarFunc(sf)) 523 } 524 return sf.Function.Coercibility() 525 } 526 527 // HasCoercibility ... 528 func (sf *ScalarFunction) HasCoercibility() bool { 529 return sf.Function.HasCoercibility() 530 } 531 532 // SetCoercibility sets a specified coercibility for this memex. 533 func (sf *ScalarFunction) SetCoercibility(val Coercibility) { 534 sf.Function.SetCoercibility(val) 535 } 536 537 // CharsetAndDefCauslation ... 538 func (sf *ScalarFunction) CharsetAndDefCauslation(ctx stochastikctx.Context) (string, string) { 539 return sf.Function.CharsetAndDefCauslation(ctx) 540 } 541 542 // SetCharsetAndDefCauslation ... 543 func (sf *ScalarFunction) SetCharsetAndDefCauslation(chs, defCausl string) { 544 sf.Function.SetCharsetAndDefCauslation(chs, defCausl) 545 }