github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/dbs/memristed/memex/scalar_function.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package memex
    15  
    16  import (
    17  	"bytes"
    18  	"fmt"
    19  
    20  	"github.com/whtcorpsinc/errors"
    21  	"github.com/whtcorpsinc/BerolinaSQL/ast"
    22  	"github.com/whtcorpsinc/BerolinaSQL/perceptron"
    23  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    24  	"github.com/whtcorpsinc/BerolinaSQL/terror"
    25  	"github.com/whtcorpsinc/milevadb/stochastikctx"
    26  	"github.com/whtcorpsinc/milevadb/stochastikctx/stmtctx"
    27  	"github.com/whtcorpsinc/milevadb/types"
    28  	"github.com/whtcorpsinc/milevadb/types/json"
    29  	"github.com/whtcorpsinc/milevadb/soliton/chunk"
    30  	"github.com/whtcorpsinc/milevadb/soliton/codec"
    31  	"github.com/whtcorpsinc/milevadb/soliton/replog"
    32  )
    33  
    34  // error definitions.
    35  var (
    36  	ErrNoDB = terror.ClassOptimizer.New(allegrosql.ErrNoDB, allegrosql.MyALLEGROSQLErrName[allegrosql.ErrNoDB])
    37  )
    38  
// ScalarFunction is the function that returns a value.
//
// Function holds the builtin implementation to which most methods delegate
// evaluation, and hashcode lazily caches the result of HashCode (Clone shares
// the cached slice with the copy rather than recomputing it).
type ScalarFunction struct {
	// FuncName is the name of the function; most methods match on its
	// lower-cased form, FuncName.L.
	FuncName perceptron.CIStr
	// RetType is the type that ScalarFunction returns.
	// TODO: Implement type inference here, now we use ast's return type temporarily.
	RetType  *types.FieldType
	Function builtinFunc
	hashcode []byte
}
    48  
    49  // VecEvalInt evaluates this memex in a vectorized manner.
    50  func (sf *ScalarFunction) VecEvalInt(ctx stochastikctx.Context, input *chunk.Chunk, result *chunk.DeferredCauset) error {
    51  	return sf.Function.vecEvalInt(input, result)
    52  }
    53  
    54  // VecEvalReal evaluates this memex in a vectorized manner.
    55  func (sf *ScalarFunction) VecEvalReal(ctx stochastikctx.Context, input *chunk.Chunk, result *chunk.DeferredCauset) error {
    56  	return sf.Function.vecEvalReal(input, result)
    57  }
    58  
    59  // VecEvalString evaluates this memex in a vectorized manner.
    60  func (sf *ScalarFunction) VecEvalString(ctx stochastikctx.Context, input *chunk.Chunk, result *chunk.DeferredCauset) error {
    61  	return sf.Function.vecEvalString(input, result)
    62  }
    63  
    64  // VecEvalDecimal evaluates this memex in a vectorized manner.
    65  func (sf *ScalarFunction) VecEvalDecimal(ctx stochastikctx.Context, input *chunk.Chunk, result *chunk.DeferredCauset) error {
    66  	return sf.Function.vecEvalDecimal(input, result)
    67  }
    68  
    69  // VecEvalTime evaluates this memex in a vectorized manner.
    70  func (sf *ScalarFunction) VecEvalTime(ctx stochastikctx.Context, input *chunk.Chunk, result *chunk.DeferredCauset) error {
    71  	return sf.Function.vecEvalTime(input, result)
    72  }
    73  
    74  // VecEvalDuration evaluates this memex in a vectorized manner.
    75  func (sf *ScalarFunction) VecEvalDuration(ctx stochastikctx.Context, input *chunk.Chunk, result *chunk.DeferredCauset) error {
    76  	return sf.Function.vecEvalDuration(input, result)
    77  }
    78  
    79  // VecEvalJSON evaluates this memex in a vectorized manner.
    80  func (sf *ScalarFunction) VecEvalJSON(ctx stochastikctx.Context, input *chunk.Chunk, result *chunk.DeferredCauset) error {
    81  	return sf.Function.vecEvalJSON(input, result)
    82  }
    83  
    84  // GetArgs gets arguments of function.
    85  func (sf *ScalarFunction) GetArgs() []Expression {
    86  	return sf.Function.getArgs()
    87  }
    88  
    89  // Vectorized returns if this memex supports vectorized evaluation.
    90  func (sf *ScalarFunction) Vectorized() bool {
    91  	return sf.Function.vectorized() && sf.Function.isChildrenVectorized()
    92  }
    93  
    94  // SupportReverseEval returns if this memex supports reversed evaluation.
    95  func (sf *ScalarFunction) SupportReverseEval() bool {
    96  	switch sf.RetType.Tp {
    97  	case allegrosql.TypeShort, allegrosql.TypeLong, allegrosql.TypeLonglong,
    98  		allegrosql.TypeFloat, allegrosql.TypeDouble, allegrosql.TypeNewDecimal:
    99  		return sf.Function.supportReverseEval() && sf.Function.isChildrenReversed()
   100  	}
   101  	return false
   102  }
   103  
   104  // ReverseEval evaluates the only one defCausumn value with given function result.
   105  func (sf *ScalarFunction) ReverseEval(sc *stmtctx.StatementContext, res types.Causet, rType types.RoundingType) (val types.Causet, err error) {
   106  	return sf.Function.reverseEval(sc, res, rType)
   107  }
   108  
   109  // GetCtx gets the context of function.
   110  func (sf *ScalarFunction) GetCtx() stochastikctx.Context {
   111  	return sf.Function.getCtx()
   112  }
   113  
   114  // String implements fmt.Stringer interface.
   115  func (sf *ScalarFunction) String() string {
   116  	var buffer bytes.Buffer
   117  	fmt.Fprintf(&buffer, "%s(", sf.FuncName.L)
   118  	switch sf.FuncName.L {
   119  	case ast.Cast:
   120  		for _, arg := range sf.GetArgs() {
   121  			buffer.WriteString(arg.String())
   122  			buffer.WriteString(", ")
   123  			buffer.WriteString(sf.RetType.String())
   124  		}
   125  	default:
   126  		for i, arg := range sf.GetArgs() {
   127  			buffer.WriteString(arg.String())
   128  			if i+1 != len(sf.GetArgs()) {
   129  				buffer.WriteString(", ")
   130  			}
   131  		}
   132  	}
   133  	buffer.WriteString(")")
   134  	return buffer.String()
   135  }
   136  
   137  // MarshalJSON implements json.Marshaler interface.
   138  func (sf *ScalarFunction) MarshalJSON() ([]byte, error) {
   139  	return []byte(fmt.Sprintf("%q", sf)), nil
   140  }
   141  
   142  // typeInferForNull infers the NULL constants field type and set the field type
   143  // of NULL constant same as other non-null operands.
   144  func typeInferForNull(args []Expression) {
   145  	if len(args) < 2 {
   146  		return
   147  	}
   148  	var isNull = func(expr Expression) bool {
   149  		cons, ok := expr.(*Constant)
   150  		return ok && cons.RetType.Tp == allegrosql.TypeNull && cons.Value.IsNull()
   151  	}
   152  	// Infer the actual field type of the NULL constant.
   153  	var retFieldTp *types.FieldType
   154  	var hasNullArg bool
   155  	for _, arg := range args {
   156  		isNullArg := isNull(arg)
   157  		if !isNullArg && retFieldTp == nil {
   158  			retFieldTp = arg.GetType()
   159  		}
   160  		hasNullArg = hasNullArg || isNullArg
   161  		// Break if there are both NULL and non-NULL memex
   162  		if hasNullArg && retFieldTp != nil {
   163  			break
   164  		}
   165  	}
   166  	if !hasNullArg || retFieldTp == nil {
   167  		return
   168  	}
   169  	for _, arg := range args {
   170  		if isNull(arg) {
   171  			*arg.GetType() = *retFieldTp
   172  		}
   173  	}
   174  }
   175  
   176  // newFunctionImpl creates a new scalar function or constant.
   177  // fold: 1 means folding constants, while 0 means not,
   178  // -1 means try to fold constants if without errors/warnings, otherwise not.
   179  func newFunctionImpl(ctx stochastikctx.Context, fold int, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
   180  	if retType == nil {
   181  		return nil, errors.Errorf("RetType cannot be nil for ScalarFunction.")
   182  	}
   183  	if funcName == ast.Cast {
   184  		return BuildCastFunction(ctx, args[0], retType), nil
   185  	}
   186  	fc, ok := funcs[funcName]
   187  	if !ok {
   188  		EDB := ctx.GetStochastikVars().CurrentDB
   189  		if EDB == "" {
   190  			return nil, errors.Trace(ErrNoDB)
   191  		}
   192  
   193  		return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", EDB+"."+funcName)
   194  	}
   195  	if !ctx.GetStochastikVars().EnableNoopFuncs {
   196  		if _, ok := noopFuncs[funcName]; ok {
   197  			return nil, ErrFunctionsNoopImpl.GenWithStackByArgs(funcName)
   198  		}
   199  	}
   200  	funcArgs := make([]Expression, len(args))
   201  	copy(funcArgs, args)
   202  	typeInferForNull(funcArgs)
   203  	f, err := fc.getFunction(ctx, funcArgs)
   204  	if err != nil {
   205  		return nil, err
   206  	}
   207  	if builtinRetTp := f.getRetTp(); builtinRetTp.Tp != allegrosql.TypeUnspecified || retType.Tp == allegrosql.TypeUnspecified {
   208  		retType = builtinRetTp
   209  	}
   210  	sf := &ScalarFunction{
   211  		FuncName: perceptron.NewCIStr(funcName),
   212  		RetType:  retType,
   213  		Function: f,
   214  	}
   215  	if fold == 1 {
   216  		return FoldConstant(sf), nil
   217  	} else if fold == -1 {
   218  		// try to fold constants, and return the original function if errors/warnings occur
   219  		sc := ctx.GetStochastikVars().StmtCtx
   220  		beforeWarns := sc.WarningCount()
   221  		newSf := FoldConstant(sf)
   222  		afterWarns := sc.WarningCount()
   223  		if afterWarns > beforeWarns {
   224  			sc.TruncateWarnings(int(beforeWarns))
   225  			return sf, nil
   226  		}
   227  		return newSf, nil
   228  	}
   229  	return sf, nil
   230  }
   231  
   232  // NewFunction creates a new scalar function or constant via a constant folding.
   233  func NewFunction(ctx stochastikctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
   234  	return newFunctionImpl(ctx, 1, funcName, retType, args...)
   235  }
   236  
   237  // NewFunctionBase creates a new scalar function with no constant folding.
   238  func NewFunctionBase(ctx stochastikctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
   239  	return newFunctionImpl(ctx, 0, funcName, retType, args...)
   240  }
   241  
   242  // NewFunctionTryFold creates a new scalar function with trying constant folding.
   243  func NewFunctionTryFold(ctx stochastikctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
   244  	return newFunctionImpl(ctx, -1, funcName, retType, args...)
   245  }
   246  
   247  // NewFunctionInternal is similar to NewFunction, but do not returns error, should only be used internally.
   248  func NewFunctionInternal(ctx stochastikctx.Context, funcName string, retType *types.FieldType, args ...Expression) Expression {
   249  	expr, err := NewFunction(ctx, funcName, retType, args...)
   250  	terror.Log(err)
   251  	return expr
   252  }
   253  
   254  // ScalarFuncs2Exprs converts []*ScalarFunction to []Expression.
   255  func ScalarFuncs2Exprs(funcs []*ScalarFunction) []Expression {
   256  	result := make([]Expression, 0, len(funcs))
   257  	for _, defCaus := range funcs {
   258  		result = append(result, defCaus)
   259  	}
   260  	return result
   261  }
   262  
   263  // Clone implements Expression interface.
   264  func (sf *ScalarFunction) Clone() Expression {
   265  	c := &ScalarFunction{
   266  		FuncName: sf.FuncName,
   267  		RetType:  sf.RetType,
   268  		Function: sf.Function.Clone(),
   269  		hashcode: sf.hashcode,
   270  	}
   271  	c.SetCharsetAndDefCauslation(sf.CharsetAndDefCauslation(sf.GetCtx()))
   272  	c.SetCoercibility(sf.Coercibility())
   273  	return c
   274  }
   275  
   276  // GetType implements Expression interface.
   277  func (sf *ScalarFunction) GetType() *types.FieldType {
   278  	return sf.RetType
   279  }
   280  
   281  // Equal implements Expression interface.
   282  func (sf *ScalarFunction) Equal(ctx stochastikctx.Context, e Expression) bool {
   283  	fun, ok := e.(*ScalarFunction)
   284  	if !ok {
   285  		return false
   286  	}
   287  	if sf.FuncName.L != fun.FuncName.L {
   288  		return false
   289  	}
   290  	return sf.Function.equal(fun.Function)
   291  }
   292  
   293  // IsCorrelated implements Expression interface.
   294  func (sf *ScalarFunction) IsCorrelated() bool {
   295  	for _, arg := range sf.GetArgs() {
   296  		if arg.IsCorrelated() {
   297  			return true
   298  		}
   299  	}
   300  	return false
   301  }
   302  
   303  // ConstItem implements Expression interface.
   304  func (sf *ScalarFunction) ConstItem(sc *stmtctx.StatementContext) bool {
   305  	// Note: some unfoldable functions are deterministic, we use unFoldableFunctions here for simplification.
   306  	if _, ok := unFoldableFunctions[sf.FuncName.L]; ok {
   307  		return false
   308  	}
   309  	for _, arg := range sf.GetArgs() {
   310  		if !arg.ConstItem(sc) {
   311  			return false
   312  		}
   313  	}
   314  	return true
   315  }
   316  
   317  // Decorrelate implements Expression interface.
   318  func (sf *ScalarFunction) Decorrelate(schemaReplicant *Schema) Expression {
   319  	for i, arg := range sf.GetArgs() {
   320  		sf.GetArgs()[i] = arg.Decorrelate(schemaReplicant)
   321  	}
   322  	return sf
   323  }
   324  
   325  // Eval implements Expression interface.
   326  func (sf *ScalarFunction) Eval(event chunk.Event) (d types.Causet, err error) {
   327  	var (
   328  		res    interface{}
   329  		isNull bool
   330  	)
   331  	switch tp, evalType := sf.GetType(), sf.GetType().EvalType(); evalType {
   332  	case types.ETInt:
   333  		var intRes int64
   334  		intRes, isNull, err = sf.EvalInt(sf.GetCtx(), event)
   335  		if allegrosql.HasUnsignedFlag(tp.Flag) {
   336  			res = uint64(intRes)
   337  		} else {
   338  			res = intRes
   339  		}
   340  	case types.ETReal:
   341  		res, isNull, err = sf.EvalReal(sf.GetCtx(), event)
   342  	case types.ETDecimal:
   343  		res, isNull, err = sf.EvalDecimal(sf.GetCtx(), event)
   344  	case types.ETDatetime, types.ETTimestamp:
   345  		res, isNull, err = sf.EvalTime(sf.GetCtx(), event)
   346  	case types.ETDuration:
   347  		res, isNull, err = sf.EvalDuration(sf.GetCtx(), event)
   348  	case types.ETJson:
   349  		res, isNull, err = sf.EvalJSON(sf.GetCtx(), event)
   350  	case types.ETString:
   351  		res, isNull, err = sf.EvalString(sf.GetCtx(), event)
   352  	}
   353  
   354  	if isNull || err != nil {
   355  		d.SetNull()
   356  		return d, err
   357  	}
   358  	d.SetValue(res, sf.RetType)
   359  	return
   360  }
   361  
   362  // EvalInt implements Expression interface.
   363  func (sf *ScalarFunction) EvalInt(ctx stochastikctx.Context, event chunk.Event) (int64, bool, error) {
   364  	if f, ok := sf.Function.(builtinFuncNew); ok {
   365  		return f.evalIntWithCtx(ctx, event)
   366  	}
   367  	return sf.Function.evalInt(event)
   368  }
   369  
   370  // EvalReal implements Expression interface.
   371  func (sf *ScalarFunction) EvalReal(ctx stochastikctx.Context, event chunk.Event) (float64, bool, error) {
   372  	return sf.Function.evalReal(event)
   373  }
   374  
   375  // EvalDecimal implements Expression interface.
   376  func (sf *ScalarFunction) EvalDecimal(ctx stochastikctx.Context, event chunk.Event) (*types.MyDecimal, bool, error) {
   377  	return sf.Function.evalDecimal(event)
   378  }
   379  
   380  // EvalString implements Expression interface.
   381  func (sf *ScalarFunction) EvalString(ctx stochastikctx.Context, event chunk.Event) (string, bool, error) {
   382  	return sf.Function.evalString(event)
   383  }
   384  
   385  // EvalTime implements Expression interface.
   386  func (sf *ScalarFunction) EvalTime(ctx stochastikctx.Context, event chunk.Event) (types.Time, bool, error) {
   387  	return sf.Function.evalTime(event)
   388  }
   389  
   390  // EvalDuration implements Expression interface.
   391  func (sf *ScalarFunction) EvalDuration(ctx stochastikctx.Context, event chunk.Event) (types.Duration, bool, error) {
   392  	return sf.Function.evalDuration(event)
   393  }
   394  
   395  // EvalJSON implements Expression interface.
   396  func (sf *ScalarFunction) EvalJSON(ctx stochastikctx.Context, event chunk.Event) (json.BinaryJSON, bool, error) {
   397  	return sf.Function.evalJSON(event)
   398  }
   399  
   400  // HashCode implements Expression interface.
   401  func (sf *ScalarFunction) HashCode(sc *stmtctx.StatementContext) []byte {
   402  	if len(sf.hashcode) > 0 {
   403  		return sf.hashcode
   404  	}
   405  	sf.hashcode = append(sf.hashcode, scalarFunctionFlag)
   406  	sf.hashcode = codec.EncodeCompactBytes(sf.hashcode, replog.Slice(sf.FuncName.L))
   407  	for _, arg := range sf.GetArgs() {
   408  		sf.hashcode = append(sf.hashcode, arg.HashCode(sc)...)
   409  	}
   410  	return sf.hashcode
   411  }
   412  
   413  // ResolveIndices implements Expression interface.
   414  func (sf *ScalarFunction) ResolveIndices(schemaReplicant *Schema) (Expression, error) {
   415  	newSf := sf.Clone()
   416  	err := newSf.resolveIndices(schemaReplicant)
   417  	return newSf, err
   418  }
   419  
   420  func (sf *ScalarFunction) resolveIndices(schemaReplicant *Schema) error {
   421  	if sf.FuncName.L == ast.In {
   422  		args := []Expression{}
   423  		switch inFunc := sf.Function.(type) {
   424  		case *builtinInIntSig:
   425  			args = inFunc.nonConstArgs
   426  		case *builtinInStringSig:
   427  			args = inFunc.nonConstArgs
   428  		case *builtinInTimeSig:
   429  			args = inFunc.nonConstArgs
   430  		case *builtinInDurationSig:
   431  			args = inFunc.nonConstArgs
   432  		case *builtinInRealSig:
   433  			args = inFunc.nonConstArgs
   434  		case *builtinInDecimalSig:
   435  			args = inFunc.nonConstArgs
   436  		}
   437  		for _, arg := range args {
   438  			err := arg.resolveIndices(schemaReplicant)
   439  			if err != nil {
   440  				return err
   441  			}
   442  		}
   443  	}
   444  	for _, arg := range sf.GetArgs() {
   445  		err := arg.resolveIndices(schemaReplicant)
   446  		if err != nil {
   447  			return err
   448  		}
   449  	}
   450  	return nil
   451  }
   452  
   453  // GetSingleDeferredCauset returns (DefCaus, Desc) when the ScalarFunction is equivalent to (DefCaus, Desc)
   454  // when used as a sort key, otherwise returns (nil, false).
   455  //
   456  // Can only handle:
   457  // - ast.Plus
   458  // - ast.Minus
   459  // - ast.UnaryMinus
   460  func (sf *ScalarFunction) GetSingleDeferredCauset(reverse bool) (*DeferredCauset, bool) {
   461  	switch sf.FuncName.String() {
   462  	case ast.Plus:
   463  		args := sf.GetArgs()
   464  		switch tp := args[0].(type) {
   465  		case *DeferredCauset:
   466  			if _, ok := args[1].(*Constant); !ok {
   467  				return nil, false
   468  			}
   469  			return tp, reverse
   470  		case *ScalarFunction:
   471  			if _, ok := args[1].(*Constant); !ok {
   472  				return nil, false
   473  			}
   474  			return tp.GetSingleDeferredCauset(reverse)
   475  		case *Constant:
   476  			switch rtp := args[1].(type) {
   477  			case *DeferredCauset:
   478  				return rtp, reverse
   479  			case *ScalarFunction:
   480  				return rtp.GetSingleDeferredCauset(reverse)
   481  			}
   482  		}
   483  		return nil, false
   484  	case ast.Minus:
   485  		args := sf.GetArgs()
   486  		switch tp := args[0].(type) {
   487  		case *DeferredCauset:
   488  			if _, ok := args[1].(*Constant); !ok {
   489  				return nil, false
   490  			}
   491  			return tp, reverse
   492  		case *ScalarFunction:
   493  			if _, ok := args[1].(*Constant); !ok {
   494  				return nil, false
   495  			}
   496  			return tp.GetSingleDeferredCauset(reverse)
   497  		case *Constant:
   498  			switch rtp := args[1].(type) {
   499  			case *DeferredCauset:
   500  				return rtp, !reverse
   501  			case *ScalarFunction:
   502  				return rtp.GetSingleDeferredCauset(!reverse)
   503  			}
   504  		}
   505  		return nil, false
   506  	case ast.UnaryMinus:
   507  		args := sf.GetArgs()
   508  		switch tp := args[0].(type) {
   509  		case *DeferredCauset:
   510  			return tp, !reverse
   511  		case *ScalarFunction:
   512  			return tp.GetSingleDeferredCauset(!reverse)
   513  		}
   514  		return nil, false
   515  	}
   516  	return nil, false
   517  }
   518  
   519  // Coercibility returns the coercibility value which is used to check defCauslations.
   520  func (sf *ScalarFunction) Coercibility() Coercibility {
   521  	if !sf.Function.HasCoercibility() {
   522  		sf.SetCoercibility(deriveCoercibilityForScarlarFunc(sf))
   523  	}
   524  	return sf.Function.Coercibility()
   525  }
   526  
   527  // HasCoercibility ...
   528  func (sf *ScalarFunction) HasCoercibility() bool {
   529  	return sf.Function.HasCoercibility()
   530  }
   531  
   532  // SetCoercibility sets a specified coercibility for this memex.
   533  func (sf *ScalarFunction) SetCoercibility(val Coercibility) {
   534  	sf.Function.SetCoercibility(val)
   535  }
   536  
   537  // CharsetAndDefCauslation ...
   538  func (sf *ScalarFunction) CharsetAndDefCauslation(ctx stochastikctx.Context) (string, string) {
   539  	return sf.Function.CharsetAndDefCauslation(ctx)
   540  }
   541  
   542  // SetCharsetAndDefCauslation ...
   543  func (sf *ScalarFunction) SetCharsetAndDefCauslation(chs, defCausl string) {
   544  	sf.Function.SetCharsetAndDefCauslation(chs, defCausl)
   545  }