github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/replace_cross_joins.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "github.com/dolthub/go-mysql-server/sql" 19 "github.com/dolthub/go-mysql-server/sql/expression" 20 "github.com/dolthub/go-mysql-server/sql/plan" 21 "github.com/dolthub/go-mysql-server/sql/transform" 22 ) 23 24 // comparisonSatisfiesJoinCondition checks a) whether a comparison is a valid join predicate, 25 // and b) whether the Left/Right children of a comparison expression covers the dependency trees 26 // of a plan.CrossJoin's children. 27 func comparisonSatisfiesJoinCondition(expr expression.Comparer, j *plan.JoinNode) bool { 28 lCols := j.Left().Schema() 29 rCols := j.Right().Schema() 30 31 var re, le *expression.GetField 32 switch e := expr.(type) { 33 case *expression.Equals, *expression.NullSafeEquals, *expression.GreaterThan, 34 *expression.LessThan, *expression.LessThanOrEqual, *expression.GreaterThanOrEqual: 35 36 ce, ok := e.(expression.Comparer) 37 if !ok { 38 return false 39 } 40 le, ok = ce.Left().(*expression.GetField) 41 if !ok { 42 return false 43 } 44 re, ok = ce.Right().(*expression.GetField) 45 if !ok { 46 return false 47 } 48 default: 49 return false 50 } 51 52 return lCols.Contains(le.Name(), le.Table()) && rCols.Contains(re.Name(), re.Table()) || 53 rCols.Contains(le.Name(), le.Table()) && lCols.Contains(re.Name(), re.Table()) 54 } 55 56 // expressionCoversJoin checks whether a subexpressions's comparison predicate 57 // satisfies the join condition. The input conjunctions have already been split, 58 // so we do not care which predicate satisfies the expression. 59 func expressionCoversJoin(c sql.Expression, j *plan.JoinNode) (found bool) { 60 return transform.InspectExpr(c, func(expr sql.Expression) bool { 61 switch e := expr.(type) { 62 case expression.Comparer: 63 return comparisonSatisfiesJoinCondition(e, j) 64 } 65 return false 66 }) 67 } 68 69 // replaceCrossJoins recursively replaces filter nested cross joins with equivalent inner joins. 70 // There are 3 phases after we identify a Filter -> ... -> CrossJoin pattern. 71 // 1. Build a list of disjunct predicate expressions by top-down splitting conjunctions (AND). 72 // 2. For every CrossJoin, check whether a subset of predicates covers as join conditions, 73 // and create a new InnerJoin with the matching predicates. 74 // 3. Remove predicates from the parent Filter that have been pushed into InnerJoins. 75 func replaceCrossJoins(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 76 if !n.Resolved() { 77 return n, transform.SameTree, nil 78 } 79 80 return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { 81 f, ok := n.(*plan.Filter) 82 if !ok { 83 return n, transform.SameTree, nil 84 } 85 predicates := expression.SplitConjunction(f.Expression) 86 movedPredicates := make(map[int]struct{}) 87 newF, _, err := transform.Node(f, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { 88 cj, ok := n.(*plan.JoinNode) 89 if !ok || !cj.Op.IsCross() { 90 return n, transform.SameTree, nil 91 } 92 93 joinConjs := make([]int, 0, len(predicates)) 94 for i, c := range predicates { 95 if expressionCoversJoin(c, cj) { 96 joinConjs = append(joinConjs, i) 97 } 98 } 99 100 if len(joinConjs) == 0 { 101 return n, transform.SameTree, nil 102 } 103 104 newExprs := make([]sql.Expression, len(joinConjs)) 105 for i, v := range joinConjs { 106 movedPredicates[v] = struct{}{} 107 newExprs[i] = predicates[v] 108 } 109 // retain comment 110 nij := plan.NewInnerJoin(cj.Left(), cj.Right(), expression.JoinAnd(newExprs...)) 111 return nij.WithComment(cj.Comment()), transform.NewTree, nil 112 }) 113 if err != nil { 114 return f, transform.SameTree, err 115 } 116 117 // only alter the Filter expression tree if we transferred predicates to an InnerJoin 118 if len(movedPredicates) == 0 { 119 return f, transform.SameTree, nil 120 } 121 122 // remove Filter if all expressions were transferred to joins 123 if len(predicates) == len(movedPredicates) { 124 return newF.(*plan.Filter).Child, transform.NewTree, nil 125 } 126 127 newFilterExprs := make([]sql.Expression, 0, len(predicates)-len(movedPredicates)) 128 for i, e := range predicates { 129 if _, ok := movedPredicates[i]; ok { 130 continue 131 } 132 newFilterExprs = append(newFilterExprs, e) 133 } 134 newF, err = newF.(*plan.Filter).WithExpressions(expression.JoinAnd(newFilterExprs...)) 135 return newF, transform.NewTree, err 136 }) 137 }