vitess.io/vitess@v0.16.2/go/vt/vtgate/planbuilder/rewrite.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package planbuilder 18 19 import ( 20 "vitess.io/vitess/go/vt/sqlparser" 21 "vitess.io/vitess/go/vt/vterrors" 22 "vitess.io/vitess/go/vt/vtgate/engine" 23 "vitess.io/vitess/go/vt/vtgate/semantics" 24 ) 25 26 type rewriter struct { 27 semTable *semantics.SemTable 28 reservedVars *sqlparser.ReservedVars 29 inSubquery int 30 err error 31 } 32 33 func queryRewrite(semTable *semantics.SemTable, reservedVars *sqlparser.ReservedVars, statement sqlparser.Statement) error { 34 r := rewriter{ 35 semTable: semTable, 36 reservedVars: reservedVars, 37 } 38 sqlparser.Rewrite(statement, r.rewriteDown, r.rewriteUp) 39 return nil 40 } 41 42 func (r *rewriter) rewriteDown(cursor *sqlparser.Cursor) bool { 43 switch node := cursor.Node().(type) { 44 case *sqlparser.Select: 45 rewriteHavingClause(node) 46 case *sqlparser.ComparisonExpr: 47 err := rewriteInSubquery(cursor, r, node) 48 if err != nil { 49 r.err = err 50 } 51 case *sqlparser.ExistsExpr: 52 err := r.rewriteExistsSubquery(cursor, node) 53 if err != nil { 54 r.err = err 55 } 56 return false 57 case *sqlparser.AliasedTableExpr: 58 // rewrite names of the routed tables for the subquery 59 // We only need to do this for non-derived tables and if they are in a subquery 60 if _, isDerived := node.Expr.(*sqlparser.DerivedTable); isDerived || r.inSubquery == 0 { 61 break 62 } 63 // find the tableSet and tableInfo that this table points to 64 // tableInfo should contain the information for the original table that the routed table points to 65 tableSet := r.semTable.TableSetFor(node) 66 tableInfo, err := r.semTable.TableInfoFor(tableSet) 67 if err != nil { 68 // Fail-safe code, should never happen 69 break 70 } 71 // vindexTable is the original table 72 vindexTable := tableInfo.GetVindexTable() 73 if vindexTable == nil { 74 break 75 } 76 tableName := node.Expr.(sqlparser.TableName) 77 // if the table name matches what the original is, then we do not need to rewrite 78 if sqlparser.Equals.IdentifierCS(vindexTable.Name, tableName.Name) { 79 break 80 } 81 // if there is no as clause, then move the routed table to the as clause. 82 // i.e 83 // routed as x -> original as x 84 // routed -> original as routed 85 if node.As.IsEmpty() { 86 node.As = tableName.Name 87 } 88 // replace the table name with the original table 89 tableName.Name = vindexTable.Name 90 node.Expr = tableName 91 case *sqlparser.Subquery: 92 err := rewriteSubquery(cursor, r, node) 93 if err != nil { 94 r.err = err 95 } 96 } 97 return true 98 } 99 100 func (r *rewriter) rewriteUp(cursor *sqlparser.Cursor) bool { 101 switch cursor.Node().(type) { 102 case *sqlparser.Subquery: 103 r.inSubquery-- 104 } 105 return r.err == nil 106 } 107 108 func rewriteInSubquery(cursor *sqlparser.Cursor, r *rewriter, node *sqlparser.ComparisonExpr) error { 109 subq, exp := semantics.GetSubqueryAndOtherSide(node) 110 if subq == nil || exp == nil { 111 return nil 112 } 113 114 semTableSQ, found := r.semTable.SubqueryRef[subq] 115 if !found { 116 return vterrors.VT13001("got subquery that was not in the subq map") 117 } 118 119 r.inSubquery++ 120 argName, hasValuesArg := r.reservedVars.ReserveSubQueryWithHasValues() 121 semTableSQ.SetArgName(argName) 122 semTableSQ.SetHasValuesArg(hasValuesArg) 123 cursor.Replace(semTableSQ) 124 return nil 125 } 126 127 func rewriteSubquery(cursor *sqlparser.Cursor, r *rewriter, node *sqlparser.Subquery) error { 128 semTableSQ, found := r.semTable.SubqueryRef[node] 129 if !found { 130 return vterrors.VT13001("got subquery that was not in the subq map") 131 } 132 if semTableSQ.GetArgName() != "" || engine.PulloutOpcode(semTableSQ.OpCode) != engine.PulloutValue { 133 return nil 134 } 135 r.inSubquery++ 136 argName := r.reservedVars.ReserveSubQuery() 137 semTableSQ.SetArgName(argName) 138 cursor.Replace(semTableSQ) 139 return nil 140 } 141 142 func (r *rewriter) rewriteExistsSubquery(cursor *sqlparser.Cursor, node *sqlparser.ExistsExpr) error { 143 semTableSQ, found := r.semTable.SubqueryRef[node.Subquery] 144 if !found { 145 return vterrors.VT13001("got subquery that was not in the subq map") 146 } 147 148 r.inSubquery++ 149 hasValuesArg := r.reservedVars.ReserveHasValuesSubQuery() 150 semTableSQ.SetHasValuesArg(hasValuesArg) 151 cursor.Replace(semTableSQ) 152 return nil 153 } 154 155 func rewriteHavingClause(node *sqlparser.Select) { 156 if node.Having == nil { 157 return 158 } 159 160 selectExprMap := map[string]sqlparser.Expr{} 161 for _, selectExpr := range node.SelectExprs { 162 aliasedExpr, isAliased := selectExpr.(*sqlparser.AliasedExpr) 163 if !isAliased || aliasedExpr.As.IsEmpty() { 164 continue 165 } 166 selectExprMap[aliasedExpr.As.Lowered()] = aliasedExpr.Expr 167 } 168 169 // for each expression in the having clause, we check if it contains aggregation. 170 // if it does, we keep the expression in the having clause ; and if it does not 171 // and the expression is in the select list, we replace the expression by the one 172 // used in the select list and add it to the where clause instead of the having clause. 173 exprs := sqlparser.SplitAndExpression(nil, node.Having.Expr) 174 node.Having = nil 175 for _, expr := range exprs { 176 hasAggr := sqlparser.ContainsAggregation(expr) 177 if !hasAggr { 178 sqlparser.Rewrite(expr, func(cursor *sqlparser.Cursor) bool { 179 visitColName(cursor.Node(), selectExprMap, func(original sqlparser.Expr) { 180 if sqlparser.ContainsAggregation(original) { 181 hasAggr = true 182 } 183 }) 184 return true 185 }, nil) 186 } 187 if hasAggr { 188 node.AddHaving(expr) 189 } else { 190 sqlparser.Rewrite(expr, func(cursor *sqlparser.Cursor) bool { 191 visitColName(cursor.Node(), selectExprMap, func(original sqlparser.Expr) { 192 cursor.Replace(original) 193 }) 194 return true 195 }, nil) 196 node.AddWhere(expr) 197 } 198 } 199 } 200 func visitColName(cursor sqlparser.SQLNode, selectExprMap map[string]sqlparser.Expr, f func(original sqlparser.Expr)) { 201 switch x := cursor.(type) { 202 case *sqlparser.ColName: 203 if !x.Qualifier.IsEmpty() { 204 return 205 } 206 originalExpr, isInMap := selectExprMap[x.Name.Lowered()] 207 if isInMap { 208 f(originalExpr) 209 } 210 return 211 } 212 }