github.com/whtcorpsinc/MilevaDB-Prod@v0.0.0-20211104133533-f57f4be3b597/dbs/memristed/memex/constant_fold.go (about) 1 // Copyright 2020 WHTCORPS INC, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package memex 15 16 import ( 17 "github.com/whtcorpsinc/BerolinaSQL/ast" 18 "github.com/whtcorpsinc/milevadb/soliton/chunk" 19 "github.com/whtcorpsinc/milevadb/soliton/logutil" 20 "go.uber.org/zap" 21 ) 22 23 // specialFoldHandler stores functions for special UDF to constant fold 24 var specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){} 25 26 func init() { 27 specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){ 28 ast.If: ifFoldHandler, 29 ast.Ifnull: ifNullFoldHandler, 30 ast.Case: caseWhenHandler, 31 } 32 } 33 34 // FoldConstant does constant folding optimization on an memex excluding deferred ones. 35 func FoldConstant(expr Expression) Expression { 36 e, _ := foldConstant(expr) 37 // keep the original coercibility values after folding 38 e.SetCoercibility(expr.Coercibility()) 39 return e 40 } 41 42 func ifFoldHandler(expr *ScalarFunction) (Expression, bool) { 43 args := expr.GetArgs() 44 foldedArg0, _ := foldConstant(args[0]) 45 if constArg, isConst := foldedArg0.(*Constant); isConst { 46 arg0, isNull0, err := constArg.EvalInt(expr.Function.getCtx(), chunk.Event{}) 47 if err != nil { 48 // Failed to fold this expr to a constant, print the DEBUG log and 49 // return the original memex to let the error to be evaluated 50 // again, in that time, the error is returned to the client. 51 logutil.BgLogger().Debug("fold memex to constant", zap.String("memex", expr.ExplainInfo()), zap.Error(err)) 52 return expr, false 53 } 54 if !isNull0 && arg0 != 0 { 55 return foldConstant(args[1]) 56 } 57 return foldConstant(args[2]) 58 } 59 // if the condition is not const, which branch is unknown to run, so directly return. 60 return expr, false 61 } 62 63 func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) { 64 args := expr.GetArgs() 65 foldedArg0, isDeferred := foldConstant(args[0]) 66 if constArg, isConst := foldedArg0.(*Constant); isConst { 67 // Only check constArg.Value here. Because deferred memex is 68 // evaluated to constArg.Value after foldConstant(args[0]), it's not 69 // needed to be checked. 70 if constArg.Value.IsNull() { 71 return foldConstant(args[1]) 72 } 73 return constArg, isDeferred 74 } 75 // if the condition is not const, which branch is unknown to run, so directly return. 76 return expr, false 77 } 78 79 func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { 80 args, l := expr.GetArgs(), len(expr.GetArgs()) 81 var isDeferred, isDeferredConst bool 82 for i := 0; i < l-1; i += 2 { 83 expr.GetArgs()[i], isDeferred = foldConstant(args[i]) 84 isDeferredConst = isDeferredConst || isDeferred 85 if _, isConst := expr.GetArgs()[i].(*Constant); isConst { 86 // If the condition is const and true, and the previous conditions 87 // has no expr, then the folded execution body is returned, otherwise 88 // the arguments of the casewhen are folded and replaced. 89 val, isNull, err := args[i].EvalInt(expr.GetCtx(), chunk.Event{}) 90 if err != nil { 91 return expr, false 92 } 93 if val != 0 && !isNull { 94 foldedExpr, isDeferred := foldConstant(args[i+1]) 95 isDeferredConst = isDeferredConst || isDeferred 96 if _, isConst := foldedExpr.(*Constant); isConst { 97 foldedExpr.GetType().Decimal = expr.GetType().Decimal 98 return foldedExpr, isDeferredConst 99 } 100 return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst 101 } 102 } else { 103 // for no-const, here should return directly, because the following branches are unknown to be run or not 104 return expr, false 105 } 106 } 107 // If the number of arguments in casewhen is odd, and the previous conditions 108 // is false, then the folded else execution body is returned. otherwise 109 // the execution body of the else are folded and replaced. 110 if l%2 == 1 { 111 foldedExpr, isDeferred := foldConstant(args[l-1]) 112 isDeferredConst = isDeferredConst || isDeferred 113 if _, isConst := foldedExpr.(*Constant); isConst { 114 foldedExpr.GetType().Decimal = expr.GetType().Decimal 115 return foldedExpr, isDeferredConst 116 } 117 return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst 118 } 119 return expr, isDeferredConst 120 } 121 122 func foldConstant(expr Expression) (Expression, bool) { 123 switch x := expr.(type) { 124 case *ScalarFunction: 125 if _, ok := unFoldableFunctions[x.FuncName.L]; ok { 126 return expr, false 127 } 128 if function := specialFoldHandler[x.FuncName.L]; function != nil { 129 return function(x) 130 } 131 132 args := x.GetArgs() 133 sc := x.GetCtx().GetStochastikVars().StmtCtx 134 argIsConst := make([]bool, len(args)) 135 hasNullArg := false 136 allConstArg := true 137 isDeferredConst := false 138 for i := 0; i < len(args); i++ { 139 switch x := args[i].(type) { 140 case *Constant: 141 isDeferredConst = isDeferredConst || x.DeferredExpr != nil || x.ParamMarker != nil 142 argIsConst[i] = true 143 hasNullArg = hasNullArg || x.Value.IsNull() 144 default: 145 allConstArg = false 146 } 147 } 148 if !allConstArg { 149 if !hasNullArg || !sc.InNullRejectCheck || x.FuncName.L == ast.NullEQ { 150 return expr, isDeferredConst 151 } 152 constArgs := make([]Expression, len(args)) 153 for i, arg := range args { 154 if argIsConst[i] { 155 constArgs[i] = arg 156 } else { 157 constArgs[i] = NewOne() 158 } 159 } 160 dummyScalarFunc, err := NewFunctionBase(x.GetCtx(), x.FuncName.L, x.GetType(), constArgs...) 161 if err != nil { 162 return expr, isDeferredConst 163 } 164 value, err := dummyScalarFunc.Eval(chunk.Event{}) 165 if err != nil { 166 return expr, isDeferredConst 167 } 168 if value.IsNull() { 169 if isDeferredConst { 170 return &Constant{Value: value, RetType: x.RetType, DeferredExpr: x}, true 171 } 172 return &Constant{Value: value, RetType: x.RetType}, false 173 } 174 if isTrue, err := value.ToBool(sc); err == nil && isTrue == 0 { 175 if isDeferredConst { 176 return &Constant{Value: value, RetType: x.RetType, DeferredExpr: x}, true 177 } 178 return &Constant{Value: value, RetType: x.RetType}, false 179 } 180 return expr, isDeferredConst 181 } 182 value, err := x.Eval(chunk.Event{}) 183 if err != nil { 184 logutil.BgLogger().Debug("fold memex to constant", zap.String("memex", x.ExplainInfo()), zap.Error(err)) 185 return expr, isDeferredConst 186 } 187 if isDeferredConst { 188 return &Constant{Value: value, RetType: x.RetType, DeferredExpr: x}, true 189 } 190 return &Constant{Value: value, RetType: x.RetType}, false 191 case *Constant: 192 if x.ParamMarker != nil { 193 return &Constant{ 194 Value: x.ParamMarker.GetUserVar(), 195 RetType: x.RetType, 196 DeferredExpr: x.DeferredExpr, 197 ParamMarker: x.ParamMarker, 198 }, true 199 } else if x.DeferredExpr != nil { 200 value, err := x.DeferredExpr.Eval(chunk.Event{}) 201 if err != nil { 202 logutil.BgLogger().Debug("fold memex to constant", zap.String("memex", x.ExplainInfo()), zap.Error(err)) 203 return expr, true 204 } 205 return &Constant{Value: value, RetType: x.RetType, DeferredExpr: x.DeferredExpr}, true 206 } 207 } 208 return expr, false 209 }