vitess.io/vitess@v0.16.2/go/vt/vtgate/semantics/early_rewriter.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package semantics 18 19 import ( 20 "strconv" 21 "strings" 22 23 "vitess.io/vitess/go/vt/vtgate/evalengine" 24 25 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 26 "vitess.io/vitess/go/vt/sqlparser" 27 "vitess.io/vitess/go/vt/vterrors" 28 ) 29 30 type earlyRewriter struct { 31 binder *binder 32 scoper *scoper 33 clause string 34 warning string 35 expandedColumns map[sqlparser.TableName][]*sqlparser.ColName 36 } 37 38 func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { 39 switch node := cursor.Node().(type) { 40 case *sqlparser.Where: 41 if node.Type != sqlparser.HavingClause { 42 return nil 43 } 44 rewriteHavingAndOrderBy(node, cursor.Parent()) 45 case sqlparser.SelectExprs: 46 _, isSel := cursor.Parent().(*sqlparser.Select) 47 if !isSel { 48 return nil 49 } 50 err := r.expandStar(cursor, node) 51 if err != nil { 52 return err 53 } 54 case *sqlparser.JoinTableExpr: 55 if node.Join == sqlparser.StraightJoinType { 56 node.Join = sqlparser.NormalJoinType 57 r.warning = "straight join is converted to normal join" 58 } 59 case sqlparser.OrderBy: 60 r.clause = "order clause" 61 rewriteHavingAndOrderBy(node, cursor.Parent()) 62 case *sqlparser.OrExpr: 63 newNode := rewriteOrFalse(*node) 64 if newNode != nil { 65 cursor.Replace(newNode) 66 } 67 case sqlparser.GroupBy: 68 r.clause = "group statement" 69 70 case *sqlparser.Literal: 71 newNode, err := r.rewriteOrderByExpr(node) 72 if err != nil { 73 return err 74 } 75 if newNode != nil { 76 cursor.Replace(newNode) 77 } 78 case *sqlparser.CollateExpr: 79 lit, ok := node.Expr.(*sqlparser.Literal) 80 if !ok { 81 return nil 82 } 83 newNode, err := r.rewriteOrderByExpr(lit) 84 if err != nil { 85 return err 86 } 87 if newNode != nil { 88 node.Expr = newNode 89 } 90 case *sqlparser.ComparisonExpr: 91 lft, lftOK := node.Left.(sqlparser.ValTuple) 92 rgt, rgtOK := node.Right.(sqlparser.ValTuple) 93 if !lftOK || !rgtOK || len(lft) != len(rgt) || node.Operator != sqlparser.EqualOp { 94 return nil 95 } 96 var predicates []sqlparser.Expr 97 for i, l := range lft { 98 r := rgt[i] 99 predicates = append(predicates, &sqlparser.ComparisonExpr{ 100 Operator: sqlparser.EqualOp, 101 Left: l, 102 Right: r, 103 Escape: node.Escape, 104 }) 105 } 106 cursor.Replace(sqlparser.AndExpressions(predicates...)) 107 } 108 return nil 109 } 110 111 func (r *earlyRewriter) expandStar(cursor *sqlparser.Cursor, node sqlparser.SelectExprs) error { 112 currentScope := r.scoper.currentScope() 113 var selExprs sqlparser.SelectExprs 114 changed := false 115 for _, selectExpr := range node { 116 starExpr, isStarExpr := selectExpr.(*sqlparser.StarExpr) 117 if !isStarExpr { 118 selExprs = append(selExprs, selectExpr) 119 continue 120 } 121 starExpanded, colNames, err := r.expandTableColumns(starExpr, currentScope.tables, r.binder.usingJoinInfo, r.scoper.org) 122 if err != nil { 123 return err 124 } 125 if !starExpanded || colNames == nil { 126 selExprs = append(selExprs, selectExpr) 127 continue 128 } 129 selExprs = append(selExprs, colNames...) 130 changed = true 131 } 132 if changed { 133 cursor.ReplaceAndRevisit(selExprs) 134 } 135 return nil 136 } 137 138 // rewriteHavingAndOrderBy rewrites columns on the ORDER BY/HAVING 139 // clauses to use aliases from the SELECT expressions when available. 140 // The scoping rules are: 141 // - A column identifier with no table qualifier that matches an alias introduced 142 // in SELECT points to that expression, and not at any table column 143 // - Except when expression aliased is an aggregation, and the column identifier in the 144 // HAVING/ORDER BY clause is inside an aggregation function 145 // 146 // This is a fucking weird scoping rule, but it's what MySQL seems to do... ¯\_(ツ)_/¯ 147 func rewriteHavingAndOrderBy(node, parent sqlparser.SQLNode) { 148 // TODO - clean up and comment this mess 149 sel, isSel := parent.(*sqlparser.Select) 150 if !isSel { 151 return 152 } 153 154 sqlparser.SafeRewrite(node, func(node, _ sqlparser.SQLNode) bool { 155 _, isSubQ := node.(*sqlparser.Subquery) 156 return !isSubQ 157 }, func(cursor *sqlparser.Cursor) bool { 158 col, ok := cursor.Node().(*sqlparser.ColName) 159 if !ok { 160 return true 161 } 162 if !col.Qualifier.IsEmpty() { 163 return true 164 } 165 _, parentIsAggr := cursor.Parent().(sqlparser.AggrFunc) 166 for _, e := range sel.SelectExprs { 167 ae, ok := e.(*sqlparser.AliasedExpr) 168 if !ok || !ae.As.Equal(col.Name) { 169 continue 170 } 171 _, aliasPointsToAggr := ae.Expr.(sqlparser.AggrFunc) 172 if parentIsAggr && aliasPointsToAggr { 173 return false 174 } 175 176 safeToRewrite := true 177 _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { 178 switch node.(type) { 179 case *sqlparser.ColName: 180 safeToRewrite = false 181 return false, nil 182 case sqlparser.AggrFunc: 183 return false, nil 184 } 185 return true, nil 186 }, ae.Expr) 187 if safeToRewrite { 188 cursor.Replace(ae.Expr) 189 } 190 } 191 return true 192 }) 193 } 194 195 func (r *earlyRewriter) rewriteOrderByExpr(node *sqlparser.Literal) (sqlparser.Expr, error) { 196 currScope, found := r.scoper.specialExprScopes[node] 197 if !found { 198 return nil, nil 199 } 200 num, err := strconv.Atoi(node.Val) 201 if err != nil { 202 return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "error parsing column number: %s", node.Val) 203 } 204 stmt, isSel := currScope.stmt.(*sqlparser.Select) 205 if !isSel { 206 return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error invalid statement type, expect Select, got: %T", currScope.stmt) 207 } 208 209 if num < 1 || num > len(stmt.SelectExprs) { 210 return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "Unknown column '%d' in '%s'", num, r.clause) 211 } 212 213 for i := 0; i < num; i++ { 214 expr := stmt.SelectExprs[i] 215 _, ok := expr.(*sqlparser.AliasedExpr) 216 if !ok { 217 return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "cannot use column offsets in %s when using `%s`", r.clause, sqlparser.String(expr)) 218 } 219 } 220 221 aliasedExpr, ok := stmt.SelectExprs[num-1].(*sqlparser.AliasedExpr) 222 if !ok { 223 return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "don't know how to handle %s", sqlparser.String(node)) 224 } 225 226 if !aliasedExpr.As.IsEmpty() { 227 return sqlparser.NewColName(aliasedExpr.As.String()), nil 228 } 229 230 expr := realCloneOfColNames(aliasedExpr.Expr, currScope.isUnion) 231 return expr, nil 232 } 233 234 // realCloneOfColNames clones all the expressions including ColName. 235 // Since sqlparser.CloneRefOfColName does not clone col names, this method is needed. 236 func realCloneOfColNames(expr sqlparser.Expr, union bool) sqlparser.Expr { 237 return sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) { 238 exp, ok := cursor.Node().(*sqlparser.ColName) 239 if !ok { 240 return 241 } 242 243 newColName := *exp 244 if union { 245 newColName.Qualifier = sqlparser.TableName{} 246 } 247 cursor.Replace(&newColName) 248 }, nil).(sqlparser.Expr) 249 } 250 251 func rewriteOrFalse(orExpr sqlparser.OrExpr) sqlparser.Expr { 252 // we are looking for the pattern `WHERE c = 1 OR 1 = 0` 253 isFalse := func(subExpr sqlparser.Expr) bool { 254 evalEnginePred, err := evalengine.Translate(subExpr, nil) 255 if err != nil { 256 return false 257 } 258 259 env := evalengine.EmptyExpressionEnv() 260 res, err := env.Evaluate(evalEnginePred) 261 if err != nil { 262 return false 263 } 264 265 boolValue, err := res.Value().ToBool() 266 if err != nil { 267 return false 268 } 269 270 return !boolValue 271 } 272 273 if isFalse(orExpr.Left) { 274 return orExpr.Right 275 } else if isFalse(orExpr.Right) { 276 return orExpr.Left 277 } 278 279 return nil 280 } 281 282 func rewriteJoinUsing( 283 current *scope, 284 using sqlparser.Columns, 285 org originable, 286 ) error { 287 joinUsing := current.prepareUsingMap() 288 predicates := make([]sqlparser.Expr, 0, len(using)) 289 for _, column := range using { 290 var foundTables []sqlparser.TableName 291 for _, tbl := range current.tables { 292 if !tbl.authoritative() { 293 return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "can't handle JOIN USING without authoritative tables") 294 } 295 296 currTable := tbl.getTableSet(org) 297 usingCols := joinUsing[currTable] 298 if usingCols == nil { 299 usingCols = map[string]TableSet{} 300 } 301 for _, col := range tbl.getColumns() { 302 _, found := usingCols[strings.ToLower(col.Name)] 303 if found { 304 tblName, err := tbl.Name() 305 if err != nil { 306 return err 307 } 308 309 foundTables = append(foundTables, tblName) 310 break // no need to look at other columns in this table 311 } 312 } 313 } 314 for i, lft := range foundTables { 315 for j := i + 1; j < len(foundTables); j++ { 316 rgt := foundTables[j] 317 predicates = append(predicates, &sqlparser.ComparisonExpr{ 318 Operator: sqlparser.EqualOp, 319 Left: sqlparser.NewColNameWithQualifier(column.String(), lft), 320 Right: sqlparser.NewColNameWithQualifier(column.String(), rgt), 321 }) 322 } 323 } 324 } 325 326 // now, we go up the scope until we find a SELECT with a where clause we can add this predicate to 327 for current != nil { 328 sel, found := current.stmt.(*sqlparser.Select) 329 if found { 330 if sel.Where == nil { 331 sel.Where = &sqlparser.Where{ 332 Type: sqlparser.WhereClause, 333 Expr: sqlparser.AndExpressions(predicates...), 334 } 335 } else { 336 sel.Where.Expr = sqlparser.AndExpressions(append(predicates, sel.Where.Expr)...) 337 } 338 return nil 339 } 340 current = current.parent 341 } 342 return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "did not find WHERE clause") 343 } 344 345 func (r *earlyRewriter) expandTableColumns( 346 starExpr *sqlparser.StarExpr, 347 tables []TableInfo, 348 joinUsing map[TableSet]map[string]TableSet, 349 org originable, 350 ) (bool, sqlparser.SelectExprs, error) { 351 unknownTbl := true 352 var colNames sqlparser.SelectExprs 353 starExpanded := true 354 expandedColumns := map[sqlparser.TableName][]*sqlparser.ColName{} 355 for _, tbl := range tables { 356 if !starExpr.TableName.IsEmpty() && !tbl.matches(starExpr.TableName) { 357 continue 358 } 359 unknownTbl = false 360 if !tbl.authoritative() { 361 starExpanded = false 362 break 363 } 364 tblName, err := tbl.Name() 365 if err != nil { 366 return false, nil, err 367 } 368 369 needsQualifier := len(tables) > 1 370 tableAliased := !tbl.getExpr().As.IsEmpty() 371 withQualifier := needsQualifier || tableAliased 372 currTable := tbl.getTableSet(org) 373 usingCols := joinUsing[currTable] 374 if usingCols == nil { 375 usingCols = map[string]TableSet{} 376 } 377 378 addColName := func(col ColumnInfo) { 379 var colName *sqlparser.ColName 380 var alias sqlparser.IdentifierCI 381 if withQualifier { 382 colName = sqlparser.NewColNameWithQualifier(col.Name, tblName) 383 } else { 384 colName = sqlparser.NewColName(col.Name) 385 } 386 if needsQualifier { 387 alias = sqlparser.NewIdentifierCI(col.Name) 388 } 389 colNames = append(colNames, &sqlparser.AliasedExpr{Expr: colName, As: alias}) 390 vt := tbl.GetVindexTable() 391 if vt != nil { 392 keyspace := vt.Keyspace 393 var ks sqlparser.IdentifierCS 394 if keyspace != nil { 395 ks = sqlparser.NewIdentifierCS(keyspace.Name) 396 } 397 tblName := sqlparser.TableName{ 398 Name: tblName.Name, 399 Qualifier: ks, 400 } 401 expandedColumns[tblName] = append(expandedColumns[tblName], colName) 402 } 403 } 404 405 /* 406 Redundant column elimination and column ordering occurs according to standard SQL, producing this display order: 407 * First, coalesced common columns of the two joined tables, in the order in which they occur in the first table 408 * Second, columns unique to the first table, in order in which they occur in that table 409 * Third, columns unique to the second table, in order in which they occur in that table 410 411 From: https://dev.mysql.com/doc/refman/8.0/en/join.html 412 */ 413 outer: 414 // in this first loop we just find columns used in any JOIN USING used on this table 415 for _, col := range tbl.getColumns() { 416 ts, found := usingCols[col.Name] 417 if found { 418 for i, ts := range ts.Constituents() { 419 if ts == currTable { 420 if i == 0 { 421 addColName(col) 422 } else { 423 continue outer 424 } 425 } 426 } 427 } 428 } 429 430 // and this time around we are printing any columns not involved in any JOIN USING 431 for _, col := range tbl.getColumns() { 432 if ts, found := usingCols[col.Name]; found && currTable.IsSolvedBy(ts) { 433 continue 434 } 435 436 addColName(col) 437 } 438 } 439 440 if unknownTbl { 441 // This will only happen for case when starExpr has qualifier. 442 return false, nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadDb, "Unknown table '%s'", sqlparser.String(starExpr.TableName)) 443 } 444 if starExpanded { 445 r.expandedColumns = expandedColumns 446 } 447 return starExpanded, colNames, nil 448 }