github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/dbs/memristed/memex/constant_fold.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package memex
    15  
    16  import (
    17  	"github.com/whtcorpsinc/BerolinaSQL/ast"
    18  	"github.com/whtcorpsinc/milevadb/soliton/chunk"
    19  	"github.com/whtcorpsinc/milevadb/soliton/logutil"
    20  	"go.uber.org/zap"
    21  )
    22  
    23  // specialFoldHandler stores functions for special UDF to constant fold
    24  var specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){}
    25  
    26  func init() {
    27  	specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){
    28  		ast.If:     ifFoldHandler,
    29  		ast.Ifnull: ifNullFoldHandler,
    30  		ast.Case:   caseWhenHandler,
    31  	}
    32  }
    33  
    34  // FoldConstant does constant folding optimization on an memex excluding deferred ones.
    35  func FoldConstant(expr Expression) Expression {
    36  	e, _ := foldConstant(expr)
    37  	// keep the original coercibility values after folding
    38  	e.SetCoercibility(expr.Coercibility())
    39  	return e
    40  }
    41  
    42  func ifFoldHandler(expr *ScalarFunction) (Expression, bool) {
    43  	args := expr.GetArgs()
    44  	foldedArg0, _ := foldConstant(args[0])
    45  	if constArg, isConst := foldedArg0.(*Constant); isConst {
    46  		arg0, isNull0, err := constArg.EvalInt(expr.Function.getCtx(), chunk.Event{})
    47  		if err != nil {
    48  			// Failed to fold this expr to a constant, print the DEBUG log and
    49  			// return the original memex to let the error to be evaluated
    50  			// again, in that time, the error is returned to the client.
    51  			logutil.BgLogger().Debug("fold memex to constant", zap.String("memex", expr.ExplainInfo()), zap.Error(err))
    52  			return expr, false
    53  		}
    54  		if !isNull0 && arg0 != 0 {
    55  			return foldConstant(args[1])
    56  		}
    57  		return foldConstant(args[2])
    58  	}
    59  	// if the condition is not const, which branch is unknown to run, so directly return.
    60  	return expr, false
    61  }
    62  
    63  func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) {
    64  	args := expr.GetArgs()
    65  	foldedArg0, isDeferred := foldConstant(args[0])
    66  	if constArg, isConst := foldedArg0.(*Constant); isConst {
    67  		// Only check constArg.Value here. Because deferred memex is
    68  		// evaluated to constArg.Value after foldConstant(args[0]), it's not
    69  		// needed to be checked.
    70  		if constArg.Value.IsNull() {
    71  			return foldConstant(args[1])
    72  		}
    73  		return constArg, isDeferred
    74  	}
    75  	// if the condition is not const, which branch is unknown to run, so directly return.
    76  	return expr, false
    77  }
    78  
    79  func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
    80  	args, l := expr.GetArgs(), len(expr.GetArgs())
    81  	var isDeferred, isDeferredConst bool
    82  	for i := 0; i < l-1; i += 2 {
    83  		expr.GetArgs()[i], isDeferred = foldConstant(args[i])
    84  		isDeferredConst = isDeferredConst || isDeferred
    85  		if _, isConst := expr.GetArgs()[i].(*Constant); isConst {
    86  			// If the condition is const and true, and the previous conditions
    87  			// has no expr, then the folded execution body is returned, otherwise
    88  			// the arguments of the casewhen are folded and replaced.
    89  			val, isNull, err := args[i].EvalInt(expr.GetCtx(), chunk.Event{})
    90  			if err != nil {
    91  				return expr, false
    92  			}
    93  			if val != 0 && !isNull {
    94  				foldedExpr, isDeferred := foldConstant(args[i+1])
    95  				isDeferredConst = isDeferredConst || isDeferred
    96  				if _, isConst := foldedExpr.(*Constant); isConst {
    97  					foldedExpr.GetType().Decimal = expr.GetType().Decimal
    98  					return foldedExpr, isDeferredConst
    99  				}
   100  				return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst
   101  			}
   102  		} else {
   103  			// for no-const, here should return directly, because the following branches are unknown to be run or not
   104  			return expr, false
   105  		}
   106  	}
   107  	// If the number of arguments in casewhen is odd, and the previous conditions
   108  	// is false, then the folded else execution body is returned. otherwise
   109  	// the execution body of the else are folded and replaced.
   110  	if l%2 == 1 {
   111  		foldedExpr, isDeferred := foldConstant(args[l-1])
   112  		isDeferredConst = isDeferredConst || isDeferred
   113  		if _, isConst := foldedExpr.(*Constant); isConst {
   114  			foldedExpr.GetType().Decimal = expr.GetType().Decimal
   115  			return foldedExpr, isDeferredConst
   116  		}
   117  		return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst
   118  	}
   119  	return expr, isDeferredConst
   120  }
   121  
   122  func foldConstant(expr Expression) (Expression, bool) {
   123  	switch x := expr.(type) {
   124  	case *ScalarFunction:
   125  		if _, ok := unFoldableFunctions[x.FuncName.L]; ok {
   126  			return expr, false
   127  		}
   128  		if function := specialFoldHandler[x.FuncName.L]; function != nil {
   129  			return function(x)
   130  		}
   131  
   132  		args := x.GetArgs()
   133  		sc := x.GetCtx().GetStochastikVars().StmtCtx
   134  		argIsConst := make([]bool, len(args))
   135  		hasNullArg := false
   136  		allConstArg := true
   137  		isDeferredConst := false
   138  		for i := 0; i < len(args); i++ {
   139  			switch x := args[i].(type) {
   140  			case *Constant:
   141  				isDeferredConst = isDeferredConst || x.DeferredExpr != nil || x.ParamMarker != nil
   142  				argIsConst[i] = true
   143  				hasNullArg = hasNullArg || x.Value.IsNull()
   144  			default:
   145  				allConstArg = false
   146  			}
   147  		}
   148  		if !allConstArg {
   149  			if !hasNullArg || !sc.InNullRejectCheck || x.FuncName.L == ast.NullEQ {
   150  				return expr, isDeferredConst
   151  			}
   152  			constArgs := make([]Expression, len(args))
   153  			for i, arg := range args {
   154  				if argIsConst[i] {
   155  					constArgs[i] = arg
   156  				} else {
   157  					constArgs[i] = NewOne()
   158  				}
   159  			}
   160  			dummyScalarFunc, err := NewFunctionBase(x.GetCtx(), x.FuncName.L, x.GetType(), constArgs...)
   161  			if err != nil {
   162  				return expr, isDeferredConst
   163  			}
   164  			value, err := dummyScalarFunc.Eval(chunk.Event{})
   165  			if err != nil {
   166  				return expr, isDeferredConst
   167  			}
   168  			if value.IsNull() {
   169  				if isDeferredConst {
   170  					return &Constant{Value: value, RetType: x.RetType, DeferredExpr: x}, true
   171  				}
   172  				return &Constant{Value: value, RetType: x.RetType}, false
   173  			}
   174  			if isTrue, err := value.ToBool(sc); err == nil && isTrue == 0 {
   175  				if isDeferredConst {
   176  					return &Constant{Value: value, RetType: x.RetType, DeferredExpr: x}, true
   177  				}
   178  				return &Constant{Value: value, RetType: x.RetType}, false
   179  			}
   180  			return expr, isDeferredConst
   181  		}
   182  		value, err := x.Eval(chunk.Event{})
   183  		if err != nil {
   184  			logutil.BgLogger().Debug("fold memex to constant", zap.String("memex", x.ExplainInfo()), zap.Error(err))
   185  			return expr, isDeferredConst
   186  		}
   187  		if isDeferredConst {
   188  			return &Constant{Value: value, RetType: x.RetType, DeferredExpr: x}, true
   189  		}
   190  		return &Constant{Value: value, RetType: x.RetType}, false
   191  	case *Constant:
   192  		if x.ParamMarker != nil {
   193  			return &Constant{
   194  				Value:        x.ParamMarker.GetUserVar(),
   195  				RetType:      x.RetType,
   196  				DeferredExpr: x.DeferredExpr,
   197  				ParamMarker:  x.ParamMarker,
   198  			}, true
   199  		} else if x.DeferredExpr != nil {
   200  			value, err := x.DeferredExpr.Eval(chunk.Event{})
   201  			if err != nil {
   202  				logutil.BgLogger().Debug("fold memex to constant", zap.String("memex", x.ExplainInfo()), zap.Error(err))
   203  				return expr, true
   204  			}
   205  			return &Constant{Value: value, RetType: x.RetType, DeferredExpr: x.DeferredExpr}, true
   206  		}
   207  	}
   208  	return expr, false
   209  }