vitess.io/vitess@v0.16.2/go/vt/vtgate/simplifier/expression_simplifier.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package simplifier
    18  
    19  import (
    20  	"fmt"
    21  	"strconv"
    22  
    23  	"vitess.io/vitess/go/vt/log"
    24  	"vitess.io/vitess/go/vt/sqlparser"
    25  )
    26  
    27  // CheckF is used to see if the given expression exhibits the sought after issue
    28  type CheckF = func(sqlparser.Expr) bool
    29  
    30  func SimplifyExpr(in sqlparser.Expr, test CheckF) (smallestKnown sqlparser.Expr) {
    31  	var maxDepth, level int
    32  	resetTo := func(e sqlparser.Expr) {
    33  		smallestKnown = e
    34  		maxDepth = depth(e)
    35  		level = 0
    36  	}
    37  	resetTo(in)
    38  	for level <= maxDepth {
    39  		current := sqlparser.CloneExpr(smallestKnown)
    40  		nodes, replaceF := getNodesAtLevel(current, level)
    41  		replace := func(e sqlparser.Expr, idx int) {
    42  			// if we are at the first level, we are replacing the root,
    43  			// not rewriting something deep in the tree
    44  			if level == 0 {
    45  				current = e
    46  			} else {
    47  				// replace `node` in current with the simplified expression
    48  				replaceF[idx](e)
    49  			}
    50  		}
    51  		simplified := false
    52  		for idx, node := range nodes {
    53  			// simplify each element and create a new expression with the node replaced by the simplification
    54  			// this means that we not only need the node, but also a way to replace the node
    55  			s := &shrinker{orig: node}
    56  			expr := s.Next()
    57  			for expr != nil {
    58  				replace(expr, idx)
    59  
    60  				valid := test(current)
    61  				log.Errorf("test: %t - %s", valid, sqlparser.String(current))
    62  				if valid {
    63  					simplified = true
    64  					break // we will still continue trying to simplify other expressions at this level
    65  				} else {
    66  					// undo the change
    67  					replace(node, idx)
    68  				}
    69  				expr = s.Next()
    70  			}
    71  		}
    72  		if simplified {
    73  			resetTo(current)
    74  		} else {
    75  			level++
    76  		}
    77  	}
    78  	return smallestKnown
    79  }
    80  
    81  func getNodesAtLevel(e sqlparser.Expr, level int) (result []sqlparser.Expr, replaceF []func(node sqlparser.SQLNode)) {
    82  	lvl := 0
    83  	pre := func(cursor *sqlparser.Cursor) bool {
    84  		if expr, isExpr := cursor.Node().(sqlparser.Expr); level == lvl && isExpr {
    85  			result = append(result, expr)
    86  			replaceF = append(replaceF, cursor.ReplacerF())
    87  		}
    88  		lvl++
    89  		return true
    90  	}
    91  	post := func(cursor *sqlparser.Cursor) bool {
    92  		lvl--
    93  		return true
    94  	}
    95  	sqlparser.Rewrite(e, pre, post)
    96  	return
    97  }
    98  
    99  func depth(e sqlparser.Expr) (depth int) {
   100  	lvl := 0
   101  	pre := func(cursor *sqlparser.Cursor) bool {
   102  		lvl++
   103  		if lvl > depth {
   104  			depth = lvl
   105  		}
   106  		return true
   107  	}
   108  	post := func(cursor *sqlparser.Cursor) bool {
   109  		lvl--
   110  		return true
   111  	}
   112  	sqlparser.Rewrite(e, pre, post)
   113  	return
   114  }
   115  
   116  type shrinker struct {
   117  	orig  sqlparser.Expr
   118  	queue []sqlparser.Expr
   119  }
   120  
   121  func (s *shrinker) Next() sqlparser.Expr {
   122  	for {
   123  		// first we check if there is already something in the queue.
   124  		// note that we are doing a nil check and not a length check here.
   125  		// once something has been added to the queue, we are no longer
   126  		// going to add expressions to the queue
   127  		if s.queue != nil {
   128  			if len(s.queue) == 0 {
   129  				return nil
   130  			}
   131  			nxt := s.queue[0]
   132  			s.queue = s.queue[1:]
   133  			return nxt
   134  		}
   135  		if s.fillQueue() {
   136  			continue
   137  		}
   138  		return nil
   139  	}
   140  }
   141  
   142  func (s *shrinker) fillQueue() bool {
   143  	before := len(s.queue)
   144  	switch e := s.orig.(type) {
   145  	case *sqlparser.ComparisonExpr:
   146  		s.queue = append(s.queue, e.Left, e.Right)
   147  	case *sqlparser.BinaryExpr:
   148  		s.queue = append(s.queue, e.Left, e.Right)
   149  	case *sqlparser.Literal:
   150  		switch e.Type {
   151  		case sqlparser.StrVal:
   152  			half := len(e.Val) / 2
   153  			if half >= 1 {
   154  				s.queue = append(s.queue, &sqlparser.Literal{Type: sqlparser.StrVal, Val: e.Val[:half]})
   155  				s.queue = append(s.queue, &sqlparser.Literal{Type: sqlparser.StrVal, Val: e.Val[half:]})
   156  			} else {
   157  				return false
   158  			}
   159  		case sqlparser.IntVal:
   160  			num, err := strconv.ParseInt(e.Val, 0, 64)
   161  			if err != nil {
   162  				panic(err)
   163  			}
   164  			if num == 0 {
   165  				// can't simplify this more
   166  				return false
   167  			}
   168  
   169  			// we'll simplify by halving the current value and decreasing it by one
   170  			half := num / 2
   171  			oneLess := num - 1
   172  			if num < 0 {
   173  				oneLess = num + 1
   174  			}
   175  
   176  			s.queue = append(s.queue, sqlparser.NewIntLiteral(fmt.Sprintf("%d", half)))
   177  			if oneLess != half {
   178  				s.queue = append(s.queue, sqlparser.NewIntLiteral(fmt.Sprintf("%d", oneLess)))
   179  			}
   180  		case sqlparser.FloatVal, sqlparser.DecimalVal:
   181  			fval, err := strconv.ParseFloat(e.Val, 64)
   182  			if err != nil {
   183  				panic(err)
   184  			}
   185  
   186  			if e.Type == sqlparser.DecimalVal {
   187  				// if it's a decimal, try to simplify as float
   188  				fval := strconv.FormatFloat(fval, 'e', -1, 64)
   189  				s.queue = append(s.queue, sqlparser.NewFloatLiteral(fval))
   190  			}
   191  
   192  			// add the value as an integer
   193  			intval := int(fval)
   194  			s.queue = append(s.queue, sqlparser.NewIntLiteral(fmt.Sprintf("%d", intval)))
   195  
   196  			// we'll simplify by halving the current value and decreasing it by one
   197  			half := fval / 2
   198  			oneLess := fval - 1
   199  			if fval < 0 {
   200  				oneLess = fval + 1
   201  			}
   202  
   203  			s.queue = append(s.queue, sqlparser.NewFloatLiteral(fmt.Sprintf("%f", half)))
   204  			if oneLess != half {
   205  				s.queue = append(s.queue, sqlparser.NewFloatLiteral(fmt.Sprintf("%f", oneLess)))
   206  			}
   207  		default:
   208  			panic(fmt.Sprintf("unhandled literal type %v", e.Type))
   209  		}
   210  	case sqlparser.ValTuple:
   211  		// first we'll try the individual elements first
   212  		for _, v := range e {
   213  			s.queue = append(s.queue, v)
   214  		}
   215  		// then we'll try to use the slice but lacking elements
   216  		for i := range e {
   217  			s.queue = append(s.queue, append(e[:i], e[i+1:]...))
   218  		}
   219  	case *sqlparser.FuncExpr:
   220  		for _, ae := range e.Exprs {
   221  			expr, ok := ae.(*sqlparser.AliasedExpr)
   222  			if !ok {
   223  				continue
   224  			}
   225  			s.queue = append(s.queue, expr.Expr)
   226  		}
   227  	case sqlparser.AggrFunc:
   228  		for _, ae := range e.GetArgs() {
   229  			s.queue = append(s.queue, ae)
   230  		}
   231  	case *sqlparser.ColName:
   232  		// we can try to replace the column with a literal value
   233  		s.queue = []sqlparser.Expr{sqlparser.NewIntLiteral("0")}
   234  	default:
   235  		return false
   236  	}
   237  	return len(s.queue) > before
   238  }