vitess.io/vitess@v0.16.2/go/vt/vtgate/simplifier/simplifier.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package simplifier
    18  
    19  import (
    20  	"vitess.io/vitess/go/vt/log"
    21  	"vitess.io/vitess/go/vt/sqlparser"
    22  	"vitess.io/vitess/go/vt/vtgate/semantics"
    23  )
    24  
    25  // SimplifyStatement simplifies the AST of a query. It basically iteratively prunes leaves of the AST, as long as the pruning
    26  // continues to return true from the `test` function.
    27  func SimplifyStatement(
    28  	in sqlparser.SelectStatement,
    29  	currentDB string,
    30  	si semantics.SchemaInformation,
    31  	testF func(sqlparser.SelectStatement) bool,
    32  ) sqlparser.SelectStatement {
    33  	tables, err := getTables(in, currentDB, si)
    34  	if err != nil {
    35  		panic(err)
    36  	}
    37  
    38  	test := func(s sqlparser.SelectStatement) bool {
    39  		// Since our semantic analysis changes the AST, we clone it first, so we have a pristine AST to play with
    40  		return testF(sqlparser.CloneSelectStatement(s))
    41  	}
    42  
    43  	if success := trySimplifyUnions(sqlparser.CloneSelectStatement(in), test); success != nil {
    44  		return SimplifyStatement(success, currentDB, si, testF)
    45  	}
    46  
    47  	// first we try to simplify the query by removing any table.
    48  	// If we can remove a table and all uses of it, that's a good start
    49  	if success := tryRemoveTable(tables, sqlparser.CloneSelectStatement(in), currentDB, si, testF); success != nil {
    50  		return SimplifyStatement(success, currentDB, si, testF)
    51  	}
    52  
    53  	// now let's try to simplify * expressions
    54  	if success := simplifyStarExpr(sqlparser.CloneSelectStatement(in), test); success != nil {
    55  		return SimplifyStatement(success, currentDB, si, testF)
    56  	}
    57  
    58  	// we try to remove select expressions next
    59  	if success := trySimplifyExpressions(sqlparser.CloneSelectStatement(in), test); success != nil {
    60  		return SimplifyStatement(success, currentDB, si, testF)
    61  	}
    62  	return in
    63  }
    64  
    65  func trySimplifyExpressions(in sqlparser.SelectStatement, test func(sqlparser.SelectStatement) bool) sqlparser.SelectStatement {
    66  	simplified := false
    67  	visitAllExpressionsInAST(in, func(cursor expressionCursor) bool {
    68  		// first - let's try to remove the expression
    69  		if cursor.remove() {
    70  			if test(in) {
    71  				log.Errorf("removed expression: %s", sqlparser.String(cursor.expr))
    72  				simplified = true
    73  				return false
    74  			}
    75  			cursor.restore()
    76  		}
    77  
    78  		// ok, we seem to need this expression. let's see if we can find a simpler version
    79  		s := &shrinker{orig: cursor.expr}
    80  		newExpr := s.Next()
    81  		for newExpr != nil {
    82  			cursor.replace(newExpr)
    83  			if test(in) {
    84  				log.Errorf("simplified expression: %s -> %s", sqlparser.String(cursor.expr), sqlparser.String(newExpr))
    85  				simplified = true
    86  				return false
    87  			}
    88  			newExpr = s.Next()
    89  		}
    90  
    91  		// if we get here, we failed to simplify this expression,
    92  		// so we put back in the original expression
    93  		cursor.restore()
    94  		return true
    95  	})
    96  
    97  	if simplified {
    98  		return in
    99  	}
   100  
   101  	return nil
   102  }
   103  
   104  func trySimplifyUnions(in sqlparser.SelectStatement, test func(sqlparser.SelectStatement) bool) (res sqlparser.SelectStatement) {
   105  
   106  	if union, ok := in.(*sqlparser.Union); ok {
   107  		// the root object is an UNION
   108  		if test(sqlparser.CloneSelectStatement(union.Left)) {
   109  			return union.Left
   110  		}
   111  		if test(sqlparser.CloneSelectStatement(union.Right)) {
   112  			return union.Right
   113  		}
   114  	}
   115  
   116  	abort := false
   117  
   118  	sqlparser.Rewrite(in, func(cursor *sqlparser.Cursor) bool {
   119  		switch node := cursor.Node().(type) {
   120  		case *sqlparser.Union:
   121  			if _, ok := cursor.Parent().(*sqlparser.RootNode); ok {
   122  				// we have already checked the root node
   123  				return true
   124  			}
   125  			cursor.Replace(node.Left)
   126  			clone := sqlparser.CloneSelectStatement(in)
   127  			if test(clone) {
   128  				log.Errorf("replaced UNION with one of its children")
   129  				abort = true
   130  				return true
   131  			}
   132  			cursor.Replace(node.Right)
   133  			clone = sqlparser.CloneSelectStatement(in)
   134  			if test(clone) {
   135  				log.Errorf("replaced UNION with one of its children")
   136  				abort = true
   137  				return true
   138  			}
   139  			cursor.Replace(node)
   140  		}
   141  		return true
   142  	}, func(*sqlparser.Cursor) bool {
   143  		return !abort
   144  	})
   145  
   146  	if !abort {
   147  		// we found no simplifications
   148  		return nil
   149  	}
   150  	return in
   151  }
   152  
   153  func tryRemoveTable(tables []semantics.TableInfo, in sqlparser.SelectStatement, currentDB string, si semantics.SchemaInformation, test func(sqlparser.SelectStatement) bool) sqlparser.SelectStatement {
   154  	// we start by removing one table at a time, and see if we still have an interesting plan
   155  	for idx, tbl := range tables {
   156  		clone := sqlparser.CloneSelectStatement(in)
   157  		searchedTS := semantics.SingleTableSet(idx)
   158  		simplified := removeTable(clone, searchedTS, currentDB, si)
   159  		name, _ := tbl.Name()
   160  		if simplified && test(clone) {
   161  			log.Errorf("removed table %s", sqlparser.String(name))
   162  			return clone
   163  		}
   164  	}
   165  
   166  	return nil
   167  }
   168  
   169  func getTables(in sqlparser.SelectStatement, currentDB string, si semantics.SchemaInformation) ([]semantics.TableInfo, error) {
   170  	// Since our semantic analysis changes the AST, we clone it first, so we have a pristine AST to play with
   171  	clone := sqlparser.CloneSelectStatement(in)
   172  	semTable, err := semantics.Analyze(clone, currentDB, si)
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  	return semTable.Tables, nil
   177  }
   178  
   179  func simplifyStarExpr(in sqlparser.SelectStatement, test func(sqlparser.SelectStatement) bool) sqlparser.SelectStatement {
   180  	simplified := false
   181  	sqlparser.Rewrite(in, func(cursor *sqlparser.Cursor) bool {
   182  		se, ok := cursor.Node().(*sqlparser.StarExpr)
   183  		if !ok {
   184  			return true
   185  		}
   186  		cursor.Replace(&sqlparser.AliasedExpr{
   187  			Expr: sqlparser.NewIntLiteral("0"),
   188  		})
   189  		if test(in) {
   190  			log.Errorf("replaced star with literal")
   191  			simplified = true
   192  			return false
   193  		}
   194  		cursor.Replace(se)
   195  
   196  		return true
   197  	}, nil)
   198  	if simplified {
   199  		return in
   200  	}
   201  	return nil
   202  }
   203  
// removeTable removes the table identified by searchedTS from the select statement clone, in place.
// Besides taking the table out of the FROM clause (or collapsing a join onto its other side), it also
// drops the WHERE predicates and the SELECT, GROUP BY and ORDER BY expressions that depend on that table.
//
// It returns false when the removal is abandoned. This happens when some SELECT in the tree has a FROM
// clause consisting of exactly one plain (aliased) table; the caller should then discard clone.
// It panics if semantic analysis of clone fails.
func removeTable(clone sqlparser.SelectStatement, searchedTS semantics.TableSet, db string, si semantics.SchemaInformation) bool {
	semTable, err := semantics.Analyze(clone, db, si)
	if err != nil {
		panic(err)
	}

	simplified := true
	// An expression is kept when it does not depend on the removed table.
	// Expressions that contain an aggregation are kept even when they do.
	shouldKeepExpr := func(expr sqlparser.Expr) bool {
		return !semTable.RecursiveDeps(expr).IsOverlapping(searchedTS) || sqlparser.ContainsAggregation(expr)
	}
	sqlparser.Rewrite(clone, func(cursor *sqlparser.Cursor) bool {
		switch node := cursor.Node().(type) {
		case *sqlparser.JoinTableExpr:
			// If either side of the join is the plain table being removed,
			// the whole join node is replaced by the other side.
			lft, ok := node.LeftExpr.(*sqlparser.AliasedTableExpr)
			if ok {
				ts := semTable.TableSetFor(lft)
				if searchedTS == ts {
					cursor.Replace(node.RightExpr)
				}
			}
			rgt, ok := node.RightExpr.(*sqlparser.AliasedTableExpr)
			if ok {
				ts := semTable.TableSetFor(rgt)
				if searchedTS == ts {
					cursor.Replace(node.LeftExpr)
				}
			}
		case *sqlparser.Select:
			// A SELECT whose FROM is a single plain table gives up on the whole removal, and its
			// children are not visited.
			// NOTE(review): this does not check whether that single table is the one being removed,
			// so any such SELECT in the tree (e.g. a subquery over an unrelated table) makes
			// removeTable return false. This is conservative; confirm it is intended.
			if len(node.From) == 1 {
				_, notJoin := node.From[0].(*sqlparser.AliasedTableExpr)
				if notJoin {
					simplified = false
					return false
				}
			}
			// Drop the first FROM entry that is the removed table.
			for i, tbl := range node.From {
				lft, ok := tbl.(*sqlparser.AliasedTableExpr)
				if ok {
					ts := semTable.TableSetFor(lft)
					if searchedTS == ts {
						node.From = append(node.From[:i], node.From[i+1:]...)
						return true
					}
				}
			}
		case *sqlparser.Where:
			// Rebuild the predicate from the AND-ed parts that do not depend on the removed table.
			// If no part remains, node.Expr ends up nil.
			exprs := sqlparser.SplitAndExpression(nil, node.Expr)
			var newPredicate sqlparser.Expr
			for _, expr := range exprs {
				if !semTable.RecursiveDeps(expr).IsOverlapping(searchedTS) {
					newPredicate = sqlparser.AndExpressions(newPredicate, expr)
				}
			}
			node.Expr = newPredicate
		case sqlparser.SelectExprs:
			// only the projection list of a SELECT is filtered here
			_, isSel := cursor.Parent().(*sqlparser.Select)
			if !isSel {
				return true
			}

			// Non-aliased select expressions are always kept; aliased ones are filtered with shouldKeepExpr.
			// NOTE(review): if every expression is aliased and depends on the removed table, the SELECT
			// list becomes empty. Confirm the test function rejects such a statement.
			var newExprs sqlparser.SelectExprs
			for _, ae := range node {
				expr, ok := ae.(*sqlparser.AliasedExpr)
				if !ok {
					newExprs = append(newExprs, ae)
					continue
				}
				if shouldKeepExpr(expr.Expr) {
					newExprs = append(newExprs, ae)
				}
			}
			cursor.Replace(newExprs)
		case sqlparser.GroupBy:
			// keep only the grouping expressions that survive shouldKeepExpr
			var newExprs sqlparser.GroupBy
			for _, expr := range node {
				if shouldKeepExpr(expr) {
					newExprs = append(newExprs, expr)
				}
			}
			cursor.Replace(newExprs)
		case sqlparser.OrderBy:
			// keep only the ordering entries whose expression survives shouldKeepExpr
			var newExprs sqlparser.OrderBy
			for _, expr := range node {
				if shouldKeepExpr(expr.Expr) {
					newExprs = append(newExprs, expr)
				}
			}

			cursor.Replace(newExprs)
		}
		return true
	}, nil)
	return simplified
}
   300  
   301  type expressionCursor struct {
   302  	expr    sqlparser.Expr
   303  	replace func(replaceWith sqlparser.Expr)
   304  	remove  func() bool
   305  	restore func()
   306  }
   307  
   308  func newExprCursor(expr sqlparser.Expr, replace func(replaceWith sqlparser.Expr), remove func() bool, restore func()) expressionCursor {
   309  	return expressionCursor{
   310  		expr:    expr,
   311  		replace: replace,
   312  		remove:  remove,
   313  		restore: restore,
   314  	}
   315  }
   316  
   317  // visitAllExpressionsInAST will walk the AST and visit all expressions
   318  // This cursor has a few extra capabilities that the normal sqlparser.Rewrite does not have,
   319  // such as visiting and being able to change individual expressions in a AND tree
   320  func visitAllExpressionsInAST(clone sqlparser.SelectStatement, visit func(expressionCursor) bool) {
   321  	abort := false
   322  	post := func(*sqlparser.Cursor) bool {
   323  		return !abort
   324  	}
   325  	pre := func(cursor *sqlparser.Cursor) bool {
   326  		if abort {
   327  			return true
   328  		}
   329  		switch node := cursor.Node().(type) {
   330  		case sqlparser.SelectExprs:
   331  			_, isSel := cursor.Parent().(*sqlparser.Select)
   332  			if !isSel {
   333  				return true
   334  			}
   335  			for idx := 0; idx < len(node); idx++ {
   336  				ae := node[idx]
   337  				expr, ok := ae.(*sqlparser.AliasedExpr)
   338  				if !ok {
   339  					continue
   340  				}
   341  				removed := false
   342  				original := sqlparser.CloneExpr(expr.Expr)
   343  				item := newExprCursor(
   344  					expr.Expr,
   345  					/*replace*/ func(replaceWith sqlparser.Expr) {
   346  						if removed {
   347  							panic("cant replace after remove without restore")
   348  						}
   349  						expr.Expr = replaceWith
   350  					},
   351  					/*remove*/ func() bool {
   352  						if removed {
   353  							panic("can't remove twice, silly")
   354  						}
   355  						if len(node) == 1 {
   356  							// can't remove the last expressions - we'd end up with an empty SELECT clause
   357  							return false
   358  						}
   359  						withoutElement := append(node[:idx], node[idx+1:]...)
   360  						cursor.Replace(withoutElement)
   361  						node = withoutElement
   362  						removed = true
   363  						return true
   364  					},
   365  					/*restore*/ func() {
   366  						if removed {
   367  							front := make(sqlparser.SelectExprs, idx)
   368  							copy(front, node[:idx])
   369  							back := make(sqlparser.SelectExprs, len(node)-idx)
   370  							copy(back, node[idx:])
   371  							frontWithRestoredExpr := append(front, ae)
   372  							node = append(frontWithRestoredExpr, back...)
   373  							cursor.Replace(node)
   374  							removed = false
   375  							return
   376  						}
   377  						expr.Expr = original
   378  					},
   379  				)
   380  				abort = !visit(item)
   381  			}
   382  		case *sqlparser.Where:
   383  			exprs := sqlparser.SplitAndExpression(nil, node.Expr)
   384  			set := func(input []sqlparser.Expr) {
   385  				node.Expr = sqlparser.AndExpressions(input...)
   386  				exprs = input
   387  			}
   388  			abort = !visitExpressions(exprs, set, visit)
   389  		case *sqlparser.JoinCondition:
   390  			join, ok := cursor.Parent().(*sqlparser.JoinTableExpr)
   391  			if !ok {
   392  				return true
   393  			}
   394  			if join.Join != sqlparser.NormalJoinType || node.Using != nil {
   395  				return false
   396  			}
   397  			exprs := sqlparser.SplitAndExpression(nil, node.On)
   398  			set := func(input []sqlparser.Expr) {
   399  				node.On = sqlparser.AndExpressions(input...)
   400  				exprs = input
   401  			}
   402  			abort = !visitExpressions(exprs, set, visit)
   403  		case sqlparser.GroupBy:
   404  			set := func(input []sqlparser.Expr) {
   405  				node = input
   406  				cursor.Replace(node)
   407  			}
   408  			abort = !visitExpressions(node, set, visit)
   409  		case sqlparser.OrderBy:
   410  			for idx := 0; idx < len(node); idx++ {
   411  				order := node[idx]
   412  				removed := false
   413  				original := sqlparser.CloneExpr(order.Expr)
   414  				item := newExprCursor(
   415  					order.Expr,
   416  					/*replace*/ func(replaceWith sqlparser.Expr) {
   417  						if removed {
   418  							panic("cant replace after remove without restore")
   419  						}
   420  						order.Expr = replaceWith
   421  					},
   422  					/*remove*/ func() bool {
   423  						if removed {
   424  							panic("can't remove twice, silly")
   425  						}
   426  						withoutElement := append(node[:idx], node[idx+1:]...)
   427  						if len(withoutElement) == 0 {
   428  							var nilVal sqlparser.OrderBy // this is used to create a typed nil value
   429  							cursor.Replace(nilVal)
   430  						} else {
   431  							cursor.Replace(withoutElement)
   432  						}
   433  						node = withoutElement
   434  						removed = true
   435  						return true
   436  					},
   437  					/*restore*/ func() {
   438  						if removed {
   439  							front := make(sqlparser.OrderBy, idx)
   440  							copy(front, node[:idx])
   441  							back := make(sqlparser.OrderBy, len(node)-idx)
   442  							copy(back, node[idx:])
   443  							frontWithRestoredExpr := append(front, order)
   444  							node = append(frontWithRestoredExpr, back...)
   445  							cursor.Replace(node)
   446  							removed = false
   447  							return
   448  						}
   449  						order.Expr = original
   450  					},
   451  				)
   452  				abort = visit(item)
   453  				if abort {
   454  					break
   455  				}
   456  			}
   457  		case *sqlparser.Limit:
   458  			if node.Offset != nil {
   459  				original := node.Offset
   460  				cursor := newExprCursor(node.Offset,
   461  					/*replace*/ func(replaceWith sqlparser.Expr) {
   462  						node.Offset = replaceWith
   463  					},
   464  					/*remove*/ func() bool {
   465  						node.Offset = nil
   466  						return true
   467  					},
   468  					/*restore*/ func() {
   469  						node.Offset = original
   470  					})
   471  				abort = visit(cursor)
   472  			}
   473  			if !abort && node.Rowcount != nil {
   474  				original := node.Rowcount
   475  				cursor := newExprCursor(node.Rowcount,
   476  					/*replace*/ func(replaceWith sqlparser.Expr) {
   477  						node.Rowcount = replaceWith
   478  					},
   479  					/*remove*/ func() bool {
   480  						// removing Rowcount is an invalid op
   481  						return false
   482  					},
   483  					/*restore*/ func() {
   484  						node.Rowcount = original
   485  					})
   486  				abort = visit(cursor)
   487  			}
   488  		}
   489  		return true
   490  	}
   491  	sqlparser.Rewrite(clone, pre, post)
   492  }
   493  
   494  // visitExpressions allows the cursor to visit all expressions in a slice,
   495  // and can replace or remove items and restore the slice.
   496  func visitExpressions(
   497  	exprs []sqlparser.Expr,
   498  	set func(input []sqlparser.Expr),
   499  	visit func(expressionCursor) bool,
   500  ) bool {
   501  	for idx := 0; idx < len(exprs); idx++ {
   502  		expr := exprs[idx]
   503  		removed := false
   504  		item := newExprCursor(expr,
   505  			func(replaceWith sqlparser.Expr) {
   506  				if removed {
   507  					panic("cant replace after remove without restore")
   508  				}
   509  				exprs[idx] = replaceWith
   510  				set(exprs)
   511  			},
   512  			/*remove*/ func() bool {
   513  				if removed {
   514  					panic("can't remove twice, silly")
   515  				}
   516  				exprs = append(exprs[:idx], exprs[idx+1:]...)
   517  				set(exprs)
   518  				removed = true
   519  				return true
   520  			},
   521  			/*restore*/ func() {
   522  				if removed {
   523  					front := make([]sqlparser.Expr, idx)
   524  					copy(front, exprs[:idx])
   525  					back := make([]sqlparser.Expr, len(exprs)-idx)
   526  					copy(back, exprs[idx:])
   527  					frontWithRestoredExpr := append(front, expr)
   528  					exprs = append(frontWithRestoredExpr, back...)
   529  					set(exprs)
   530  					removed = false
   531  					return
   532  				}
   533  				exprs[idx] = expr
   534  				set(exprs)
   535  			})
   536  		if !visit(item) {
   537  			return false
   538  		}
   539  	}
   540  	return true
   541  }