vitess.io/vitess@v0.16.2/go/vt/vtgate/simplifier/simplifier.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package simplifier 18 19 import ( 20 "vitess.io/vitess/go/vt/log" 21 "vitess.io/vitess/go/vt/sqlparser" 22 "vitess.io/vitess/go/vt/vtgate/semantics" 23 ) 24 25 // SimplifyStatement simplifies the AST of a query. It basically iteratively prunes leaves of the AST, as long as the pruning 26 // continues to return true from the `test` function. 27 func SimplifyStatement( 28 in sqlparser.SelectStatement, 29 currentDB string, 30 si semantics.SchemaInformation, 31 testF func(sqlparser.SelectStatement) bool, 32 ) sqlparser.SelectStatement { 33 tables, err := getTables(in, currentDB, si) 34 if err != nil { 35 panic(err) 36 } 37 38 test := func(s sqlparser.SelectStatement) bool { 39 // Since our semantic analysis changes the AST, we clone it first, so we have a pristine AST to play with 40 return testF(sqlparser.CloneSelectStatement(s)) 41 } 42 43 if success := trySimplifyUnions(sqlparser.CloneSelectStatement(in), test); success != nil { 44 return SimplifyStatement(success, currentDB, si, testF) 45 } 46 47 // first we try to simplify the query by removing any table. 48 // If we can remove a table and all uses of it, that's a good start 49 if success := tryRemoveTable(tables, sqlparser.CloneSelectStatement(in), currentDB, si, testF); success != nil { 50 return SimplifyStatement(success, currentDB, si, testF) 51 } 52 53 // now let's try to simplify * expressions 54 if success := simplifyStarExpr(sqlparser.CloneSelectStatement(in), test); success != nil { 55 return SimplifyStatement(success, currentDB, si, testF) 56 } 57 58 // we try to remove select expressions next 59 if success := trySimplifyExpressions(sqlparser.CloneSelectStatement(in), test); success != nil { 60 return SimplifyStatement(success, currentDB, si, testF) 61 } 62 return in 63 } 64 65 func trySimplifyExpressions(in sqlparser.SelectStatement, test func(sqlparser.SelectStatement) bool) sqlparser.SelectStatement { 66 simplified := false 67 visitAllExpressionsInAST(in, func(cursor expressionCursor) bool { 68 // first - let's try to remove the expression 69 if cursor.remove() { 70 if test(in) { 71 log.Errorf("removed expression: %s", sqlparser.String(cursor.expr)) 72 simplified = true 73 return false 74 } 75 cursor.restore() 76 } 77 78 // ok, we seem to need this expression. let's see if we can find a simpler version 79 s := &shrinker{orig: cursor.expr} 80 newExpr := s.Next() 81 for newExpr != nil { 82 cursor.replace(newExpr) 83 if test(in) { 84 log.Errorf("simplified expression: %s -> %s", sqlparser.String(cursor.expr), sqlparser.String(newExpr)) 85 simplified = true 86 return false 87 } 88 newExpr = s.Next() 89 } 90 91 // if we get here, we failed to simplify this expression, 92 // so we put back in the original expression 93 cursor.restore() 94 return true 95 }) 96 97 if simplified { 98 return in 99 } 100 101 return nil 102 } 103 104 func trySimplifyUnions(in sqlparser.SelectStatement, test func(sqlparser.SelectStatement) bool) (res sqlparser.SelectStatement) { 105 106 if union, ok := in.(*sqlparser.Union); ok { 107 // the root object is an UNION 108 if test(sqlparser.CloneSelectStatement(union.Left)) { 109 return union.Left 110 } 111 if test(sqlparser.CloneSelectStatement(union.Right)) { 112 return union.Right 113 } 114 } 115 116 abort := false 117 118 sqlparser.Rewrite(in, func(cursor *sqlparser.Cursor) bool { 119 switch node := cursor.Node().(type) { 120 case *sqlparser.Union: 121 if _, ok := cursor.Parent().(*sqlparser.RootNode); ok { 122 // we have already checked the root node 123 return true 124 } 125 cursor.Replace(node.Left) 126 clone := sqlparser.CloneSelectStatement(in) 127 if test(clone) { 128 log.Errorf("replaced UNION with one of its children") 129 abort = true 130 return true 131 } 132 cursor.Replace(node.Right) 133 clone = sqlparser.CloneSelectStatement(in) 134 if test(clone) { 135 log.Errorf("replaced UNION with one of its children") 136 abort = true 137 return true 138 } 139 cursor.Replace(node) 140 } 141 return true 142 }, func(*sqlparser.Cursor) bool { 143 return !abort 144 }) 145 146 if !abort { 147 // we found no simplifications 148 return nil 149 } 150 return in 151 } 152 153 func tryRemoveTable(tables []semantics.TableInfo, in sqlparser.SelectStatement, currentDB string, si semantics.SchemaInformation, test func(sqlparser.SelectStatement) bool) sqlparser.SelectStatement { 154 // we start by removing one table at a time, and see if we still have an interesting plan 155 for idx, tbl := range tables { 156 clone := sqlparser.CloneSelectStatement(in) 157 searchedTS := semantics.SingleTableSet(idx) 158 simplified := removeTable(clone, searchedTS, currentDB, si) 159 name, _ := tbl.Name() 160 if simplified && test(clone) { 161 log.Errorf("removed table %s", sqlparser.String(name)) 162 return clone 163 } 164 } 165 166 return nil 167 } 168 169 func getTables(in sqlparser.SelectStatement, currentDB string, si semantics.SchemaInformation) ([]semantics.TableInfo, error) { 170 // Since our semantic analysis changes the AST, we clone it first, so we have a pristine AST to play with 171 clone := sqlparser.CloneSelectStatement(in) 172 semTable, err := semantics.Analyze(clone, currentDB, si) 173 if err != nil { 174 return nil, err 175 } 176 return semTable.Tables, nil 177 } 178 179 func simplifyStarExpr(in sqlparser.SelectStatement, test func(sqlparser.SelectStatement) bool) sqlparser.SelectStatement { 180 simplified := false 181 sqlparser.Rewrite(in, func(cursor *sqlparser.Cursor) bool { 182 se, ok := cursor.Node().(*sqlparser.StarExpr) 183 if !ok { 184 return true 185 } 186 cursor.Replace(&sqlparser.AliasedExpr{ 187 Expr: sqlparser.NewIntLiteral("0"), 188 }) 189 if test(in) { 190 log.Errorf("replaced star with literal") 191 simplified = true 192 return false 193 } 194 cursor.Replace(se) 195 196 return true 197 }, nil) 198 if simplified { 199 return in 200 } 201 return nil 202 } 203 204 // removeTable removes the table with the given index from the select statement, which includes the FROM clause 205 // but also all expressions and predicates that depend on the table 206 func removeTable(clone sqlparser.SelectStatement, searchedTS semantics.TableSet, db string, si semantics.SchemaInformation) bool { 207 semTable, err := semantics.Analyze(clone, db, si) 208 if err != nil { 209 panic(err) 210 } 211 212 simplified := true 213 shouldKeepExpr := func(expr sqlparser.Expr) bool { 214 return !semTable.RecursiveDeps(expr).IsOverlapping(searchedTS) || sqlparser.ContainsAggregation(expr) 215 } 216 sqlparser.Rewrite(clone, func(cursor *sqlparser.Cursor) bool { 217 switch node := cursor.Node().(type) { 218 case *sqlparser.JoinTableExpr: 219 lft, ok := node.LeftExpr.(*sqlparser.AliasedTableExpr) 220 if ok { 221 ts := semTable.TableSetFor(lft) 222 if searchedTS == ts { 223 cursor.Replace(node.RightExpr) 224 } 225 } 226 rgt, ok := node.RightExpr.(*sqlparser.AliasedTableExpr) 227 if ok { 228 ts := semTable.TableSetFor(rgt) 229 if searchedTS == ts { 230 cursor.Replace(node.LeftExpr) 231 } 232 } 233 case *sqlparser.Select: 234 if len(node.From) == 1 { 235 _, notJoin := node.From[0].(*sqlparser.AliasedTableExpr) 236 if notJoin { 237 simplified = false 238 return false 239 } 240 } 241 for i, tbl := range node.From { 242 lft, ok := tbl.(*sqlparser.AliasedTableExpr) 243 if ok { 244 ts := semTable.TableSetFor(lft) 245 if searchedTS == ts { 246 node.From = append(node.From[:i], node.From[i+1:]...) 247 return true 248 } 249 } 250 } 251 case *sqlparser.Where: 252 exprs := sqlparser.SplitAndExpression(nil, node.Expr) 253 var newPredicate sqlparser.Expr 254 for _, expr := range exprs { 255 if !semTable.RecursiveDeps(expr).IsOverlapping(searchedTS) { 256 newPredicate = sqlparser.AndExpressions(newPredicate, expr) 257 } 258 } 259 node.Expr = newPredicate 260 case sqlparser.SelectExprs: 261 _, isSel := cursor.Parent().(*sqlparser.Select) 262 if !isSel { 263 return true 264 } 265 266 var newExprs sqlparser.SelectExprs 267 for _, ae := range node { 268 expr, ok := ae.(*sqlparser.AliasedExpr) 269 if !ok { 270 newExprs = append(newExprs, ae) 271 continue 272 } 273 if shouldKeepExpr(expr.Expr) { 274 newExprs = append(newExprs, ae) 275 } 276 } 277 cursor.Replace(newExprs) 278 case sqlparser.GroupBy: 279 var newExprs sqlparser.GroupBy 280 for _, expr := range node { 281 if shouldKeepExpr(expr) { 282 newExprs = append(newExprs, expr) 283 } 284 } 285 cursor.Replace(newExprs) 286 case sqlparser.OrderBy: 287 var newExprs sqlparser.OrderBy 288 for _, expr := range node { 289 if shouldKeepExpr(expr.Expr) { 290 newExprs = append(newExprs, expr) 291 } 292 } 293 294 cursor.Replace(newExprs) 295 } 296 return true 297 }, nil) 298 return simplified 299 } 300 301 type expressionCursor struct { 302 expr sqlparser.Expr 303 replace func(replaceWith sqlparser.Expr) 304 remove func() bool 305 restore func() 306 } 307 308 func newExprCursor(expr sqlparser.Expr, replace func(replaceWith sqlparser.Expr), remove func() bool, restore func()) expressionCursor { 309 return expressionCursor{ 310 expr: expr, 311 replace: replace, 312 remove: remove, 313 restore: restore, 314 } 315 } 316 317 // visitAllExpressionsInAST will walk the AST and visit all expressions 318 // This cursor has a few extra capabilities that the normal sqlparser.Rewrite does not have, 319 // such as visiting and being able to change individual expressions in a AND tree 320 func visitAllExpressionsInAST(clone sqlparser.SelectStatement, visit func(expressionCursor) bool) { 321 abort := false 322 post := func(*sqlparser.Cursor) bool { 323 return !abort 324 } 325 pre := func(cursor *sqlparser.Cursor) bool { 326 if abort { 327 return true 328 } 329 switch node := cursor.Node().(type) { 330 case sqlparser.SelectExprs: 331 _, isSel := cursor.Parent().(*sqlparser.Select) 332 if !isSel { 333 return true 334 } 335 for idx := 0; idx < len(node); idx++ { 336 ae := node[idx] 337 expr, ok := ae.(*sqlparser.AliasedExpr) 338 if !ok { 339 continue 340 } 341 removed := false 342 original := sqlparser.CloneExpr(expr.Expr) 343 item := newExprCursor( 344 expr.Expr, 345 /*replace*/ func(replaceWith sqlparser.Expr) { 346 if removed { 347 panic("cant replace after remove without restore") 348 } 349 expr.Expr = replaceWith 350 }, 351 /*remove*/ func() bool { 352 if removed { 353 panic("can't remove twice, silly") 354 } 355 if len(node) == 1 { 356 // can't remove the last expressions - we'd end up with an empty SELECT clause 357 return false 358 } 359 withoutElement := append(node[:idx], node[idx+1:]...) 360 cursor.Replace(withoutElement) 361 node = withoutElement 362 removed = true 363 return true 364 }, 365 /*restore*/ func() { 366 if removed { 367 front := make(sqlparser.SelectExprs, idx) 368 copy(front, node[:idx]) 369 back := make(sqlparser.SelectExprs, len(node)-idx) 370 copy(back, node[idx:]) 371 frontWithRestoredExpr := append(front, ae) 372 node = append(frontWithRestoredExpr, back...) 373 cursor.Replace(node) 374 removed = false 375 return 376 } 377 expr.Expr = original 378 }, 379 ) 380 abort = !visit(item) 381 } 382 case *sqlparser.Where: 383 exprs := sqlparser.SplitAndExpression(nil, node.Expr) 384 set := func(input []sqlparser.Expr) { 385 node.Expr = sqlparser.AndExpressions(input...) 386 exprs = input 387 } 388 abort = !visitExpressions(exprs, set, visit) 389 case *sqlparser.JoinCondition: 390 join, ok := cursor.Parent().(*sqlparser.JoinTableExpr) 391 if !ok { 392 return true 393 } 394 if join.Join != sqlparser.NormalJoinType || node.Using != nil { 395 return false 396 } 397 exprs := sqlparser.SplitAndExpression(nil, node.On) 398 set := func(input []sqlparser.Expr) { 399 node.On = sqlparser.AndExpressions(input...) 400 exprs = input 401 } 402 abort = !visitExpressions(exprs, set, visit) 403 case sqlparser.GroupBy: 404 set := func(input []sqlparser.Expr) { 405 node = input 406 cursor.Replace(node) 407 } 408 abort = !visitExpressions(node, set, visit) 409 case sqlparser.OrderBy: 410 for idx := 0; idx < len(node); idx++ { 411 order := node[idx] 412 removed := false 413 original := sqlparser.CloneExpr(order.Expr) 414 item := newExprCursor( 415 order.Expr, 416 /*replace*/ func(replaceWith sqlparser.Expr) { 417 if removed { 418 panic("cant replace after remove without restore") 419 } 420 order.Expr = replaceWith 421 }, 422 /*remove*/ func() bool { 423 if removed { 424 panic("can't remove twice, silly") 425 } 426 withoutElement := append(node[:idx], node[idx+1:]...) 427 if len(withoutElement) == 0 { 428 var nilVal sqlparser.OrderBy // this is used to create a typed nil value 429 cursor.Replace(nilVal) 430 } else { 431 cursor.Replace(withoutElement) 432 } 433 node = withoutElement 434 removed = true 435 return true 436 }, 437 /*restore*/ func() { 438 if removed { 439 front := make(sqlparser.OrderBy, idx) 440 copy(front, node[:idx]) 441 back := make(sqlparser.OrderBy, len(node)-idx) 442 copy(back, node[idx:]) 443 frontWithRestoredExpr := append(front, order) 444 node = append(frontWithRestoredExpr, back...) 445 cursor.Replace(node) 446 removed = false 447 return 448 } 449 order.Expr = original 450 }, 451 ) 452 abort = visit(item) 453 if abort { 454 break 455 } 456 } 457 case *sqlparser.Limit: 458 if node.Offset != nil { 459 original := node.Offset 460 cursor := newExprCursor(node.Offset, 461 /*replace*/ func(replaceWith sqlparser.Expr) { 462 node.Offset = replaceWith 463 }, 464 /*remove*/ func() bool { 465 node.Offset = nil 466 return true 467 }, 468 /*restore*/ func() { 469 node.Offset = original 470 }) 471 abort = visit(cursor) 472 } 473 if !abort && node.Rowcount != nil { 474 original := node.Rowcount 475 cursor := newExprCursor(node.Rowcount, 476 /*replace*/ func(replaceWith sqlparser.Expr) { 477 node.Rowcount = replaceWith 478 }, 479 /*remove*/ func() bool { 480 // removing Rowcount is an invalid op 481 return false 482 }, 483 /*restore*/ func() { 484 node.Rowcount = original 485 }) 486 abort = visit(cursor) 487 } 488 } 489 return true 490 } 491 sqlparser.Rewrite(clone, pre, post) 492 } 493 494 // visitExpressions allows the cursor to visit all expressions in a slice, 495 // and can replace or remove items and restore the slice. 496 func visitExpressions( 497 exprs []sqlparser.Expr, 498 set func(input []sqlparser.Expr), 499 visit func(expressionCursor) bool, 500 ) bool { 501 for idx := 0; idx < len(exprs); idx++ { 502 expr := exprs[idx] 503 removed := false 504 item := newExprCursor(expr, 505 func(replaceWith sqlparser.Expr) { 506 if removed { 507 panic("cant replace after remove without restore") 508 } 509 exprs[idx] = replaceWith 510 set(exprs) 511 }, 512 /*remove*/ func() bool { 513 if removed { 514 panic("can't remove twice, silly") 515 } 516 exprs = append(exprs[:idx], exprs[idx+1:]...) 517 set(exprs) 518 removed = true 519 return true 520 }, 521 /*restore*/ func() { 522 if removed { 523 front := make([]sqlparser.Expr, idx) 524 copy(front, exprs[:idx]) 525 back := make([]sqlparser.Expr, len(exprs)-idx) 526 copy(back, exprs[idx:]) 527 frontWithRestoredExpr := append(front, expr) 528 exprs = append(frontWithRestoredExpr, back...) 529 set(exprs) 530 removed = false 531 return 532 } 533 exprs[idx] = expr 534 set(exprs) 535 }) 536 if !visit(item) { 537 return false 538 } 539 } 540 return true 541 }