vitess.io/vitess@v0.16.2/go/vt/vtgate/simplifier/expression_simplifier.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package simplifier 18 19 import ( 20 "fmt" 21 "strconv" 22 23 "vitess.io/vitess/go/vt/log" 24 "vitess.io/vitess/go/vt/sqlparser" 25 ) 26 27 // CheckF is used to see if the given expression exhibits the sought after issue 28 type CheckF = func(sqlparser.Expr) bool 29 30 func SimplifyExpr(in sqlparser.Expr, test CheckF) (smallestKnown sqlparser.Expr) { 31 var maxDepth, level int 32 resetTo := func(e sqlparser.Expr) { 33 smallestKnown = e 34 maxDepth = depth(e) 35 level = 0 36 } 37 resetTo(in) 38 for level <= maxDepth { 39 current := sqlparser.CloneExpr(smallestKnown) 40 nodes, replaceF := getNodesAtLevel(current, level) 41 replace := func(e sqlparser.Expr, idx int) { 42 // if we are at the first level, we are replacing the root, 43 // not rewriting something deep in the tree 44 if level == 0 { 45 current = e 46 } else { 47 // replace `node` in current with the simplified expression 48 replaceF[idx](e) 49 } 50 } 51 simplified := false 52 for idx, node := range nodes { 53 // simplify each element and create a new expression with the node replaced by the simplification 54 // this means that we not only need the node, but also a way to replace the node 55 s := &shrinker{orig: node} 56 expr := s.Next() 57 for expr != nil { 58 replace(expr, idx) 59 60 valid := test(current) 61 log.Errorf("test: %t - %s", valid, sqlparser.String(current)) 62 if valid { 63 simplified = true 64 break // we will still continue trying to simplify other expressions at this level 65 } else { 66 // undo the change 67 replace(node, idx) 68 } 69 expr = s.Next() 70 } 71 } 72 if simplified { 73 resetTo(current) 74 } else { 75 level++ 76 } 77 } 78 return smallestKnown 79 } 80 81 func getNodesAtLevel(e sqlparser.Expr, level int) (result []sqlparser.Expr, replaceF []func(node sqlparser.SQLNode)) { 82 lvl := 0 83 pre := func(cursor *sqlparser.Cursor) bool { 84 if expr, isExpr := cursor.Node().(sqlparser.Expr); level == lvl && isExpr { 85 result = append(result, expr) 86 replaceF = append(replaceF, cursor.ReplacerF()) 87 } 88 lvl++ 89 return true 90 } 91 post := func(cursor *sqlparser.Cursor) bool { 92 lvl-- 93 return true 94 } 95 sqlparser.Rewrite(e, pre, post) 96 return 97 } 98 99 func depth(e sqlparser.Expr) (depth int) { 100 lvl := 0 101 pre := func(cursor *sqlparser.Cursor) bool { 102 lvl++ 103 if lvl > depth { 104 depth = lvl 105 } 106 return true 107 } 108 post := func(cursor *sqlparser.Cursor) bool { 109 lvl-- 110 return true 111 } 112 sqlparser.Rewrite(e, pre, post) 113 return 114 } 115 116 type shrinker struct { 117 orig sqlparser.Expr 118 queue []sqlparser.Expr 119 } 120 121 func (s *shrinker) Next() sqlparser.Expr { 122 for { 123 // first we check if there is already something in the queue. 124 // note that we are doing a nil check and not a length check here. 125 // once something has been added to the queue, we are no longer 126 // going to add expressions to the queue 127 if s.queue != nil { 128 if len(s.queue) == 0 { 129 return nil 130 } 131 nxt := s.queue[0] 132 s.queue = s.queue[1:] 133 return nxt 134 } 135 if s.fillQueue() { 136 continue 137 } 138 return nil 139 } 140 } 141 142 func (s *shrinker) fillQueue() bool { 143 before := len(s.queue) 144 switch e := s.orig.(type) { 145 case *sqlparser.ComparisonExpr: 146 s.queue = append(s.queue, e.Left, e.Right) 147 case *sqlparser.BinaryExpr: 148 s.queue = append(s.queue, e.Left, e.Right) 149 case *sqlparser.Literal: 150 switch e.Type { 151 case sqlparser.StrVal: 152 half := len(e.Val) / 2 153 if half >= 1 { 154 s.queue = append(s.queue, &sqlparser.Literal{Type: sqlparser.StrVal, Val: e.Val[:half]}) 155 s.queue = append(s.queue, &sqlparser.Literal{Type: sqlparser.StrVal, Val: e.Val[half:]}) 156 } else { 157 return false 158 } 159 case sqlparser.IntVal: 160 num, err := strconv.ParseInt(e.Val, 0, 64) 161 if err != nil { 162 panic(err) 163 } 164 if num == 0 { 165 // can't simplify this more 166 return false 167 } 168 169 // we'll simplify by halving the current value and decreasing it by one 170 half := num / 2 171 oneLess := num - 1 172 if num < 0 { 173 oneLess = num + 1 174 } 175 176 s.queue = append(s.queue, sqlparser.NewIntLiteral(fmt.Sprintf("%d", half))) 177 if oneLess != half { 178 s.queue = append(s.queue, sqlparser.NewIntLiteral(fmt.Sprintf("%d", oneLess))) 179 } 180 case sqlparser.FloatVal, sqlparser.DecimalVal: 181 fval, err := strconv.ParseFloat(e.Val, 64) 182 if err != nil { 183 panic(err) 184 } 185 186 if e.Type == sqlparser.DecimalVal { 187 // if it's a decimal, try to simplify as float 188 fval := strconv.FormatFloat(fval, 'e', -1, 64) 189 s.queue = append(s.queue, sqlparser.NewFloatLiteral(fval)) 190 } 191 192 // add the value as an integer 193 intval := int(fval) 194 s.queue = append(s.queue, sqlparser.NewIntLiteral(fmt.Sprintf("%d", intval))) 195 196 // we'll simplify by halving the current value and decreasing it by one 197 half := fval / 2 198 oneLess := fval - 1 199 if fval < 0 { 200 oneLess = fval + 1 201 } 202 203 s.queue = append(s.queue, sqlparser.NewFloatLiteral(fmt.Sprintf("%f", half))) 204 if oneLess != half { 205 s.queue = append(s.queue, sqlparser.NewFloatLiteral(fmt.Sprintf("%f", oneLess))) 206 } 207 default: 208 panic(fmt.Sprintf("unhandled literal type %v", e.Type)) 209 } 210 case sqlparser.ValTuple: 211 // first we'll try the individual elements first 212 for _, v := range e { 213 s.queue = append(s.queue, v) 214 } 215 // then we'll try to use the slice but lacking elements 216 for i := range e { 217 s.queue = append(s.queue, append(e[:i], e[i+1:]...)) 218 } 219 case *sqlparser.FuncExpr: 220 for _, ae := range e.Exprs { 221 expr, ok := ae.(*sqlparser.AliasedExpr) 222 if !ok { 223 continue 224 } 225 s.queue = append(s.queue, expr.Expr) 226 } 227 case sqlparser.AggrFunc: 228 for _, ae := range e.GetArgs() { 229 s.queue = append(s.queue, ae) 230 } 231 case *sqlparser.ColName: 232 // we can try to replace the column with a literal value 233 s.queue = []sqlparser.Expr{sqlparser.NewIntLiteral("0")} 234 default: 235 return false 236 } 237 return len(s.queue) > before 238 }