github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/optimization_rules.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"strings"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  	"github.com/dolthub/go-mysql-server/sql/expression"
    22  	"github.com/dolthub/go-mysql-server/sql/plan"
    23  	"github.com/dolthub/go-mysql-server/sql/transform"
    24  	"github.com/dolthub/go-mysql-server/sql/types"
    25  )
    26  
    27  // eraseProjection removes redundant Project nodes from the plan. A project
    28  // is redundant if it doesn't alter the schema of its child. Special
    29  // considerations: (1) target projections casing needs
    30  // to be preserved in the output schema even if the projection is redundant;
    31  // (2) column ids are not reliable enough to maximally prune projections,
    32  // we still need to check column/table/database names.
    33  // todo: analyzer should separate target schema from plan schema
    34  // todo: projection columns should all have ids so that pruning is more reliable
    35  func eraseProjection(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    36  	span, ctx := ctx.Span("erase_projection")
    37  	defer span.End()
    38  
    39  	if !node.Resolved() {
    40  		return node, transform.SameTree, nil
    41  	}
    42  
    43  	return transform.Node(node, func(node sql.Node) (sql.Node, transform.TreeIdentity, error) {
    44  		project, ok := node.(*plan.Project)
    45  		if ok {
    46  			if project.Schema().CaseSensitiveEquals(project.Child.Schema()) {
    47  				a.Log("project erased")
    48  				return project.Child, transform.NewTree, nil
    49  			}
    50  
    51  		}
    52  
    53  		return node, transform.SameTree, nil
    54  	})
    55  }
    56  
    57  func flattenDistinct(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    58  	return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
    59  		if d, ok := n.(*plan.Distinct); ok {
    60  			if d2, ok := d.Child.(*plan.Distinct); ok {
    61  				return d2, transform.NewTree, nil
    62  			}
    63  			if d2, ok := d.Child.(*plan.OrderedDistinct); ok {
    64  				return d2, transform.NewTree, nil
    65  			}
    66  		}
    67  		if d, ok := n.(*plan.OrderedDistinct); ok {
    68  			if d2, ok := d.Child.(*plan.Distinct); ok {
    69  				return plan.NewOrderedDistinct(d2.Child), transform.NewTree, nil
    70  			}
    71  			if d2, ok := d.Child.(*plan.OrderedDistinct); ok {
    72  				return d2, transform.NewTree, nil
    73  			}
    74  		}
    75  		return n, transform.SameTree, nil
    76  	})
    77  }
    78  
    79  // moveJoinConditionsToFilter looks for expressions in a join condition that reference only tables in the left or right
    80  // side of the join, and move those conditions to a new Filter node instead. If the join condition is empty after these
    81  // moves, the join is converted to a CrossJoin.
    82  func moveJoinConditionsToFilter(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    83  	if !n.Resolved() {
    84  		return n, transform.SameTree, nil
    85  	}
    86  
    87  	return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
    88  		var rightOnlyFilters []sql.Expression
    89  		var leftOnlyFilters []sql.Expression
    90  
    91  		join, ok := n.(*plan.JoinNode)
    92  		if !ok {
    93  			// no join
    94  			return n, transform.SameTree, nil
    95  		}
    96  
    97  		// no filter or left join: nothing to do to the tree
    98  		if join.JoinType().IsDegenerate() {
    99  			return n, transform.SameTree, nil
   100  		}
   101  		if !(join.JoinType().IsInner() || join.JoinType().IsSemi()) {
   102  			return n, transform.SameTree, nil
   103  		}
   104  		leftSources := nodeSources(join.Left())
   105  		rightSources := nodeSources(join.Right())
   106  		filtersMoved := 0
   107  		var condFilters []sql.Expression
   108  		for _, e := range expression.SplitConjunction(join.JoinCond()) {
   109  			sources, nullRej := expressionSources(e)
   110  			if !nullRej {
   111  				condFilters = append(condFilters, e)
   112  				continue
   113  			}
   114  
   115  			if sources.SubsetOf(leftSources) {
   116  				leftOnlyFilters = append(leftOnlyFilters, e)
   117  				filtersMoved++
   118  			} else if sources.SubsetOf(rightSources) {
   119  				rightOnlyFilters = append(rightOnlyFilters, e)
   120  				filtersMoved++
   121  			} else {
   122  				condFilters = append(condFilters, e)
   123  			}
   124  		}
   125  
   126  		if filtersMoved == 0 {
   127  			return n, transform.SameTree, nil
   128  		}
   129  
   130  		newLeft := join.Left()
   131  		if len(leftOnlyFilters) > 0 {
   132  			newLeft = plan.NewFilter(expression.JoinAnd(leftOnlyFilters...), newLeft)
   133  		}
   134  
   135  		newRight := join.Right()
   136  		if len(rightOnlyFilters) > 0 {
   137  			newRight = plan.NewFilter(expression.JoinAnd(rightOnlyFilters...), newRight)
   138  		}
   139  
   140  		if len(condFilters) == 0 {
   141  			condFilters = append(condFilters, expression.NewLiteral(true, types.Boolean))
   142  		}
   143  
   144  		return plan.NewJoin(newLeft, newRight, join.Op, expression.JoinAnd(condFilters...)).WithComment(join.CommentStr), transform.NewTree, nil
   145  	})
   146  }
   147  
   148  // containsSources checks that all `needle` sources are contained inside `haystack`.
   149  func containsSources(haystack, needle []sql.TableId) bool {
   150  	for _, s := range needle {
   151  		var found bool
   152  		for _, s2 := range haystack {
   153  			if s2 == s {
   154  				found = true
   155  				break
   156  			}
   157  		}
   158  
   159  		if !found {
   160  			return false
   161  		}
   162  	}
   163  
   164  	return true
   165  }
   166  
   167  // nodeSources returns the set of column sources from the schema of the node given.
   168  func nodeSources(n sql.Node) sql.FastIntSet {
   169  	var tables sql.FastIntSet
   170  	transform.InspectUp(n, func(n sql.Node) bool {
   171  		tin, _ := n.(plan.TableIdNode)
   172  		if tin != nil {
   173  			tables.Add(int(tin.Id()))
   174  		}
   175  		return false
   176  	})
   177  	return tables
   178  }
   179  
   180  // expressionSources returns the set of sources from any GetField expressions
   181  // in the expression given, and a boolean indicating whether the expression
   182  // is null rejecting from those sources.
   183  func expressionSources(expr sql.Expression) (sql.FastIntSet, bool) {
   184  	var tables sql.FastIntSet
   185  	var nullRejecting bool = true
   186  
   187  	sql.Inspect(expr, func(e sql.Expression) bool {
   188  		switch e := e.(type) {
   189  		case *expression.GetField:
   190  			tables.Add(int(e.TableId()))
   191  		case *expression.IsNull:
   192  			nullRejecting = false
   193  		case *expression.NullSafeEquals:
   194  			nullRejecting = false
   195  		case *expression.Equals:
   196  			if lit, ok := e.Left().(*expression.Literal); ok && lit.Value() == nil {
   197  				nullRejecting = false
   198  			}
   199  			if lit, ok := e.Right().(*expression.Literal); ok && lit.Value() == nil {
   200  				nullRejecting = false
   201  			}
   202  		case *plan.Subquery:
   203  			transform.InspectExpressions(e.Query, func(innerExpr sql.Expression) bool {
   204  				switch e := innerExpr.(type) {
   205  				case *expression.GetField:
   206  					tables.Add(int(e.TableId()))
   207  				case *expression.IsNull:
   208  					nullRejecting = false
   209  				case *expression.NullSafeEquals:
   210  					nullRejecting = false
   211  				case *expression.Equals:
   212  					if lit, ok := e.Left().(*expression.Literal); ok && lit.Value() == nil {
   213  						nullRejecting = false
   214  					}
   215  					if lit, ok := e.Right().(*expression.Literal); ok && lit.Value() == nil {
   216  						nullRejecting = false
   217  					}
   218  				}
   219  				return true
   220  			})
   221  		}
   222  		return true
   223  	})
   224  
   225  	return tables, nullRejecting
   226  }
   227  
   228  // simplifyFilters simplifies the expressions in Filter nodes where possible. This involves removing redundant parts of AND
   229  // and OR expressions, as well as replacing evaluable expressions with their literal result. Filters that can
   230  // statically be determined to be true or false are replaced with the child node or an empty result, respectively.
   231  func simplifyFilters(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   232  	if !node.Resolved() {
   233  		return node, transform.SameTree, nil
   234  	}
   235  
   236  	return transform.NodeWithOpaque(node, func(node sql.Node) (sql.Node, transform.TreeIdentity, error) {
   237  		filter, ok := node.(*plan.Filter)
   238  		if !ok {
   239  			return node, transform.SameTree, nil
   240  		}
   241  
   242  		e, same, err := transform.Expr(filter.Expression, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   243  			switch e := e.(type) {
   244  			case *plan.Subquery:
   245  				newQ, same, err := simplifyFilters(ctx, a, e.Query, scope, sel)
   246  				if same || err != nil {
   247  					return e, transform.SameTree, err
   248  				}
   249  				return e.WithQuery(newQ), transform.NewTree, nil
   250  			case *expression.Between:
   251  				return expression.NewAnd(
   252  					expression.NewGreaterThanOrEqual(e.Val, e.Lower),
   253  					expression.NewLessThanOrEqual(e.Val, e.Upper),
   254  				), transform.NewTree, nil
   255  			case *expression.Or:
   256  				if isTrue(e.LeftChild) {
   257  					return e.LeftChild, transform.NewTree, nil
   258  				}
   259  
   260  				if isTrue(e.RightChild) {
   261  					return e.RightChild, transform.NewTree, nil
   262  				}
   263  
   264  				if isFalse(e.LeftChild) {
   265  					return e.RightChild, transform.NewTree, nil
   266  				}
   267  
   268  				if isFalse(e.RightChild) {
   269  					return e.LeftChild, transform.NewTree, nil
   270  				}
   271  
   272  				return e, transform.SameTree, nil
   273  			case *expression.And:
   274  				if isFalse(e.LeftChild) {
   275  					return e.LeftChild, transform.NewTree, nil
   276  				}
   277  
   278  				if isFalse(e.RightChild) {
   279  					return e.RightChild, transform.NewTree, nil
   280  				}
   281  
   282  				if isTrue(e.LeftChild) {
   283  					return e.RightChild, transform.NewTree, nil
   284  				}
   285  
   286  				if isTrue(e.RightChild) {
   287  					return e.LeftChild, transform.NewTree, nil
   288  				}
   289  
   290  				return e, transform.SameTree, nil
   291  			case *expression.Like:
   292  				// if the charset is not utf8mb4, the last character used in optimization rule does not work
   293  				coll, _ := sql.GetCoercibility(ctx, e.LeftChild)
   294  				charset := coll.CharacterSet()
   295  				if charset != sql.CharacterSet_utf8mb4 {
   296  					return e, transform.SameTree, nil
   297  				}
   298  				// TODO: maybe more cases to simplify
   299  				r, ok := e.RightChild.(*expression.Literal)
   300  				if !ok {
   301  					return e, transform.SameTree, nil
   302  				}
   303  				// TODO: handle escapes
   304  				if e.Escape != nil {
   305  					return e, transform.SameTree, nil
   306  				}
   307  				val := r.Value()
   308  				valStr, ok := val.(string)
   309  				if !ok {
   310  					return e, transform.SameTree, nil
   311  				}
   312  				if len(valStr) == 0 {
   313  					return e, transform.SameTree, nil
   314  				}
   315  				// if there are single character wildcards, don't simplify
   316  				if strings.Count(valStr, "_")-strings.Count(valStr, "\\_") > 0 {
   317  					return e, transform.SameTree, nil
   318  				}
   319  				// if there are also no multiple character wildcards, this is just a plain equals
   320  				numWild := strings.Count(valStr, "%") - strings.Count(valStr, "\\%")
   321  				if numWild == 0 {
   322  					return expression.NewEquals(e.LeftChild, e.RightChild), transform.NewTree, nil
   323  				}
   324  				// if there are many multiple character wildcards, don't simplify
   325  				if numWild != 1 {
   326  					return e, transform.SameTree, nil
   327  				}
   328  				// if the last character is an escaped multiple character wildcard, don't simplify
   329  				if len(valStr) >= 2 && valStr[len(valStr)-2:] == "\\%" {
   330  					return e, transform.SameTree, nil
   331  				}
   332  				if valStr[len(valStr)-1] != '%' {
   333  					return e, transform.SameTree, nil
   334  				}
   335  				// TODO: like expression with just a wild card shouldn't even make it here; analyzer rule should just drop filter
   336  				if len(valStr) == 1 {
   337  					return e, transform.SameTree, nil
   338  				}
   339  				valStr = valStr[:len(valStr)-1]
   340  				newRightLower := expression.NewLiteral(valStr, e.RightChild.Type())
   341  				valStr += string(byte(255)) // append largest possible character as upper bound
   342  				newRightUpper := expression.NewLiteral(valStr, e.RightChild.Type())
   343  				newExpr := expression.NewAnd(expression.NewGreaterThanOrEqual(e.LeftChild, newRightLower), expression.NewLessThanOrEqual(e.LeftChild, newRightUpper))
   344  				return newExpr, transform.NewTree, nil
   345  			case *expression.Literal, expression.Tuple, *expression.Interval, *expression.CollatedExpression, *expression.MatchAgainst:
   346  				return e, transform.SameTree, nil
   347  			default:
   348  				if !isEvaluable(e) {
   349  					return e, transform.SameTree, nil
   350  				}
   351  
   352  				// All other expressions types can be evaluated once and turned into literals for the rest of query execution
   353  				val, err := e.Eval(ctx, nil)
   354  				if err != nil {
   355  					return e, transform.SameTree, nil
   356  				}
   357  				return expression.NewLiteral(val, e.Type()), transform.NewTree, nil
   358  			}
   359  		})
   360  		if err != nil {
   361  			return nil, transform.SameTree, err
   362  		}
   363  
   364  		if isFalse(e) {
   365  			emptyTable := plan.NewEmptyTableWithSchema(filter.Schema())
   366  			return emptyTable, transform.NewTree, nil
   367  		}
   368  
   369  		if isTrue(e) {
   370  			return filter.Child, transform.NewTree, nil
   371  		}
   372  
   373  		if same {
   374  			return filter, transform.SameTree, nil
   375  		}
   376  		return plan.NewFilter(e, filter.Child), transform.NewTree, nil
   377  	})
   378  }
   379  
   380  func isFalse(e sql.Expression) bool {
   381  	lit, ok := e.(*expression.Literal)
   382  	if ok && lit != nil && lit.Type() == types.Boolean && lit.Value() != nil {
   383  		switch v := lit.Value().(type) {
   384  		case bool:
   385  			return !v
   386  		case int8:
   387  			return v == sql.False
   388  		}
   389  	}
   390  	return false
   391  }
   392  
   393  func isTrue(e sql.Expression) bool {
   394  	lit, ok := e.(*expression.Literal)
   395  	if ok && lit != nil && lit.Type() == types.Boolean && lit.Value() != nil {
   396  		switch v := lit.Value().(type) {
   397  		case bool:
   398  			return v
   399  		case int8:
   400  			return v != sql.False
   401  		}
   402  	}
   403  	return false
   404  }
   405  
   406  // pushNotFilters applies De'Morgan's laws to push NOT expressions as low
   407  // in expression trees as possible and inverts NOT leaf expressions.
   408  // ref: https://en.wikipedia.org/wiki/De_Morgan%27s_laws
   409  // note: the output tree identity will not be accurate
   410  func pushNotFilters(_ *sql.Context, _ *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   411  	return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
   412  		var e sql.Expression
   413  		var err error
   414  		switch n := n.(type) {
   415  		case *plan.Filter:
   416  			e, err = pushNotFiltersHelper(n.Expression)
   417  		case *plan.JoinNode:
   418  			if n.Filter != nil {
   419  				e, err = pushNotFiltersHelper(n.Filter)
   420  			}
   421  		default:
   422  			return n, transform.SameTree, nil
   423  		}
   424  		if err != nil {
   425  			return n, transform.SameTree, nil
   426  		}
   427  		ret, err := n.(sql.Expressioner).WithExpressions(e)
   428  		if err != nil {
   429  			return n, transform.SameTree, nil
   430  		}
   431  		return ret, transform.NewTree, nil
   432  	})
   433  }
   434  
   435  // TODO maybe: NOT(INTUPLE(c...)), NOT(EQ(c))=>OR(LT(c), GT(c))
   436  func pushNotFiltersHelper(e sql.Expression) (sql.Expression, error) {
   437  	// NOT(NOT(c))=>c
   438  	if not, _ := e.(*expression.Not); not != nil {
   439  		if f, _ := not.Child.(*expression.Not); f != nil {
   440  			return pushNotFiltersHelper(f.Child)
   441  		}
   442  	}
   443  
   444  	// NOT(AND(left,right))=>OR(NOT(left), NOT(right))
   445  	if not, _ := e.(*expression.Not); not != nil {
   446  		if f, _ := not.Child.(*expression.And); f != nil {
   447  			return pushNotFiltersHelper(expression.NewOr(expression.NewNot(f.LeftChild), expression.NewNot(f.RightChild)))
   448  		}
   449  	}
   450  
   451  	// NOT(OR(left,right))=>AND(NOT(left), NOT(right))
   452  	if not, _ := e.(*expression.Not); not != nil {
   453  		if f, _ := not.Child.(*expression.Or); f != nil {
   454  			return pushNotFiltersHelper(expression.NewAnd(expression.NewNot(f.LeftChild), expression.NewNot(f.RightChild)))
   455  		}
   456  	}
   457  
   458  	// NOT(GT(c))=>LTE(c)
   459  	if not, _ := e.(*expression.Not); not != nil {
   460  		if f, _ := not.Child.(*expression.GreaterThan); f != nil {
   461  			return pushNotFiltersHelper(expression.NewLessThanOrEqual(f.Left(), f.Right()))
   462  		}
   463  	}
   464  
   465  	// NOT(GTE(c))=>LT(c)
   466  	if not, _ := e.(*expression.Not); not != nil {
   467  		if f, _ := not.Child.(*expression.GreaterThanOrEqual); f != nil {
   468  			return pushNotFiltersHelper(expression.NewLessThan(f.Left(), f.Right()))
   469  		}
   470  	}
   471  
   472  	// NOT(LT(c))=>GTE(c)
   473  	if not, _ := e.(*expression.Not); not != nil {
   474  		if f, _ := not.Child.(*expression.LessThan); f != nil {
   475  			return pushNotFiltersHelper(expression.NewGreaterThanOrEqual(f.Left(), f.Right()))
   476  		}
   477  	}
   478  
   479  	// NOT(LTE(c))=>GT(c)
   480  	if not, _ := e.(*expression.Not); not != nil {
   481  		if f, _ := not.Child.(*expression.LessThanOrEqual); f != nil {
   482  			return pushNotFiltersHelper(expression.NewGreaterThan(f.Left(), f.Right()))
   483  		}
   484  	}
   485  
   486  	//NOT(BETWEEN(left,right))=>OR(LT(left), GT(right))
   487  	if not, _ := e.(*expression.Not); not != nil {
   488  		if f, _ := not.Child.(*expression.Between); f != nil {
   489  			return pushNotFiltersHelper(expression.NewOr(
   490  				expression.NewLessThan(f.Val, f.Lower),
   491  				expression.NewGreaterThan(f.Val, f.Upper),
   492  			))
   493  		}
   494  	}
   495  
   496  	var newChildren []sql.Expression
   497  	for _, c := range e.Children() {
   498  		newC, err := pushNotFiltersHelper(c)
   499  		if err != nil {
   500  			return nil, err
   501  		}
   502  		newChildren = append(newChildren, newC)
   503  	}
   504  	return e.WithChildren(newChildren...)
   505  }