vitess.io/vitess@v0.16.2/go/vt/vtgate/semantics/analyzer_test.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package semantics 18 19 import ( 20 "fmt" 21 "testing" 22 23 "vitess.io/vitess/go/vt/vtgate/engine" 24 25 "github.com/stretchr/testify/assert" 26 "github.com/stretchr/testify/require" 27 28 querypb "vitess.io/vitess/go/vt/proto/query" 29 "vitess.io/vitess/go/vt/sqlparser" 30 "vitess.io/vitess/go/vt/vtgate/vindexes" 31 ) 32 33 var T0 TableSet 34 35 var ( 36 // Just here to make outputs more readable 37 None = EmptyTableSet() 38 T1 = SingleTableSet(0) 39 T2 = SingleTableSet(1) 40 T3 = SingleTableSet(2) 41 T4 = SingleTableSet(3) 42 T5 = SingleTableSet(4) 43 ) 44 45 func extract(in *sqlparser.Select, idx int) sqlparser.Expr { 46 return in.SelectExprs[idx].(*sqlparser.AliasedExpr).Expr 47 } 48 49 func TestBindingSingleTablePositive(t *testing.T) { 50 queries := []string{ 51 "select col from tabl", 52 "select uid from t2", 53 "select tabl.col from tabl", 54 "select d.tabl.col from tabl", 55 "select col from d.tabl", 56 "select tabl.col from d.tabl", 57 "select d.tabl.col from d.tabl", 58 "select col+col from tabl", 59 "select max(col1+col2) from d.tabl", 60 "select max(id) from t1", 61 } 62 for _, query := range queries { 63 t.Run(query, func(t *testing.T) { 64 stmt, semTable := parseAndAnalyze(t, query, "d") 65 sel, _ := stmt.(*sqlparser.Select) 66 t1 := sel.From[0].(*sqlparser.AliasedTableExpr) 67 ts := 
semTable.TableSetFor(t1) 68 assert.Equal(t, SingleTableSet(0), ts) 69 70 recursiveDeps := semTable.RecursiveDeps(extract(sel, 0)) 71 assert.Equal(t, T1, recursiveDeps, query) 72 assert.Equal(t, T1, semTable.DirectDeps(extract(sel, 0)), query) 73 assert.Equal(t, 1, recursiveDeps.NumberOfTables(), "number of tables is wrong") 74 }) 75 } 76 } 77 78 func TestInformationSchemaColumnInfo(t *testing.T) { 79 stmt, semTable := parseAndAnalyze(t, "select table_comment, file_name from information_schema.`TABLES`, information_schema.`FILES`", "d") 80 81 sel, _ := stmt.(*sqlparser.Select) 82 tables := SingleTableSet(0) 83 files := SingleTableSet(1) 84 85 assert.Equal(t, tables, semTable.RecursiveDeps(extract(sel, 0))) 86 assert.Equal(t, files, semTable.DirectDeps(extract(sel, 1))) 87 } 88 89 func TestBindingSingleAliasedTablePositive(t *testing.T) { 90 queries := []string{ 91 "select col from tabl as X", 92 "select tabl.col from X as tabl", 93 "select col from d.X as tabl", 94 "select tabl.col from d.X as tabl", 95 "select col+col from tabl as X", 96 "select max(tabl.col1 + tabl.col2) from d.X as tabl", 97 "select max(t.id) from t1 as t", 98 } 99 for _, query := range queries { 100 t.Run(query, func(t *testing.T) { 101 stmt, semTable := parseAndAnalyze(t, query, "") 102 sel, _ := stmt.(*sqlparser.Select) 103 t1 := sel.From[0].(*sqlparser.AliasedTableExpr) 104 ts := semTable.TableSetFor(t1) 105 assert.Equal(t, SingleTableSet(0), ts) 106 107 recursiveDeps := semTable.RecursiveDeps(extract(sel, 0)) 108 require.Equal(t, T1, recursiveDeps, query) 109 assert.Equal(t, 1, recursiveDeps.NumberOfTables(), "number of tables is wrong") 110 }) 111 } 112 } 113 114 func TestBindingSingleTableNegative(t *testing.T) { 115 queries := []string{ 116 "select foo.col from tabl", 117 "select ks.tabl.col from tabl", 118 "select ks.tabl.col from d.tabl", 119 "select d.tabl.col from ks.tabl", 120 "select foo.col from d.tabl", 121 "select tabl.col from d.tabl as t", 122 } 123 for _, query := range 
queries { 124 t.Run(query, func(t *testing.T) { 125 parse, err := sqlparser.Parse(query) 126 require.NoError(t, err) 127 st, err := Analyze(parse.(sqlparser.SelectStatement), "d", &FakeSI{}) 128 require.NoError(t, err) 129 require.ErrorContains(t, st.NotUnshardedErr, "symbol") 130 require.ErrorContains(t, st.NotUnshardedErr, "not found") 131 }) 132 } 133 } 134 135 func TestBindingSingleAliasedTableNegative(t *testing.T) { 136 queries := []string{ 137 "select tabl.col from tabl as X", 138 "select d.X.col from d.X as tabl", 139 "select d.tabl.col from X as tabl", 140 "select d.tabl.col from ks.X as tabl", 141 "select d.tabl.col from d.X as tabl", 142 } 143 for _, query := range queries { 144 t.Run(query, func(t *testing.T) { 145 parse, err := sqlparser.Parse(query) 146 require.NoError(t, err) 147 st, err := Analyze(parse.(sqlparser.SelectStatement), "", &FakeSI{ 148 Tables: map[string]*vindexes.Table{ 149 "t": {Name: sqlparser.NewIdentifierCS("t")}, 150 }, 151 }) 152 require.NoError(t, err) 153 require.Error(t, st.NotUnshardedErr) 154 }) 155 } 156 } 157 158 func TestBindingMultiTablePositive(t *testing.T) { 159 type testCase struct { 160 query string 161 deps TableSet 162 numberOfTables int 163 } 164 queries := []testCase{{ 165 query: "select t.col from t, s", 166 deps: T1, 167 numberOfTables: 1, 168 }, { 169 query: "select s.col from t join s", 170 deps: T2, 171 numberOfTables: 1, 172 }, { 173 query: "select max(t.col+s.col) from t, s", 174 deps: MergeTableSets(T1, T2), 175 numberOfTables: 2, 176 }, { 177 query: "select max(t.col+s.col) from t join s", 178 deps: MergeTableSets(T1, T2), 179 numberOfTables: 2, 180 }, { 181 query: "select case t.col when s.col then r.col else u.col end from t, s, r, w, u", 182 deps: MergeTableSets(T1, T2, T3, T5), 183 numberOfTables: 4, 184 // }, { 185 // TODO: move to subquery 186 // make sure that we don't let sub-query dependencies leak out by mistake 187 // query: "select t.col + (select 42 from s) from t", 188 // deps: T1, 189 // 
}, { 190 // query: "select (select 42 from s where r.id = s.id) from r", 191 // deps: T1 | T2, 192 }, { 193 query: "select u1.a + u2.a from u1, u2", 194 deps: MergeTableSets(T1, T2), 195 numberOfTables: 2, 196 }} 197 for _, query := range queries { 198 t.Run(query.query, func(t *testing.T) { 199 stmt, semTable := parseAndAnalyze(t, query.query, "user") 200 sel, _ := stmt.(*sqlparser.Select) 201 recursiveDeps := semTable.RecursiveDeps(extract(sel, 0)) 202 assert.Equal(t, query.deps, recursiveDeps, query.query) 203 assert.Equal(t, query.numberOfTables, recursiveDeps.NumberOfTables(), "number of tables is wrong") 204 }) 205 } 206 } 207 208 func TestBindingMultiAliasedTablePositive(t *testing.T) { 209 type testCase struct { 210 query string 211 deps TableSet 212 numberOfTables int 213 } 214 queries := []testCase{{ 215 query: "select X.col from t as X, s as S", 216 deps: T1, 217 numberOfTables: 1, 218 }, { 219 query: "select X.col+S.col from t as X, s as S", 220 deps: MergeTableSets(T1, T2), 221 numberOfTables: 2, 222 }, { 223 query: "select max(X.col+S.col) from t as X, s as S", 224 deps: MergeTableSets(T1, T2), 225 numberOfTables: 2, 226 }, { 227 query: "select max(X.col+s.col) from t as X, s", 228 deps: MergeTableSets(T1, T2), 229 numberOfTables: 2, 230 }} 231 for _, query := range queries { 232 t.Run(query.query, func(t *testing.T) { 233 stmt, semTable := parseAndAnalyze(t, query.query, "user") 234 sel, _ := stmt.(*sqlparser.Select) 235 recursiveDeps := semTable.RecursiveDeps(extract(sel, 0)) 236 assert.Equal(t, query.deps, recursiveDeps, query.query) 237 assert.Equal(t, query.numberOfTables, recursiveDeps.NumberOfTables(), "number of tables is wrong") 238 }) 239 } 240 } 241 242 func TestBindingMultiTableNegative(t *testing.T) { 243 queries := []string{ 244 "select 1 from d.tabl, d.tabl", 245 "select 1 from d.tabl, tabl", 246 "select t.col from k.t, t", 247 "select b.t.col from b.t, t", 248 } 249 for _, query := range queries { 250 t.Run(query, func(t *testing.T) { 
251 parse, err := sqlparser.Parse(query) 252 require.NoError(t, err) 253 _, err = Analyze(parse.(sqlparser.SelectStatement), "d", &FakeSI{ 254 Tables: map[string]*vindexes.Table{ 255 "tabl": {Name: sqlparser.NewIdentifierCS("tabl")}, 256 "foo": {Name: sqlparser.NewIdentifierCS("foo")}, 257 }, 258 }) 259 require.Error(t, err) 260 }) 261 } 262 } 263 264 func TestBindingMultiAliasedTableNegative(t *testing.T) { 265 queries := []string{ 266 "select 1 from d.tabl as tabl, d.tabl", 267 "select 1 from d.tabl as tabl, tabl", 268 "select 1 from d.tabl as a, tabl as a", 269 "select 1 from user join user_extra user", 270 "select t.col from k.t as t, t", 271 "select b.t.col from b.t as t, t", 272 } 273 for _, query := range queries { 274 t.Run(query, func(t *testing.T) { 275 parse, err := sqlparser.Parse(query) 276 require.NoError(t, err) 277 _, err = Analyze(parse.(sqlparser.SelectStatement), "d", &FakeSI{ 278 Tables: map[string]*vindexes.Table{ 279 "tabl": {Name: sqlparser.NewIdentifierCS("tabl")}, 280 "foo": {Name: sqlparser.NewIdentifierCS("foo")}, 281 }, 282 }) 283 require.Error(t, err) 284 }) 285 } 286 } 287 288 func TestNotUniqueTableName(t *testing.T) { 289 queries := []string{ 290 "select * from t, t", 291 "select * from t, (select 1 from x) as t", 292 "select * from t join t", 293 "select * from t join (select 1 from x) as t", 294 } 295 296 for _, query := range queries { 297 t.Run(query, func(t *testing.T) { 298 parse, _ := sqlparser.Parse(query) 299 _, err := Analyze(parse.(sqlparser.SelectStatement), "test", &FakeSI{}) 300 require.Error(t, err) 301 require.Contains(t, err.Error(), "VT03013: not unique table/alias") 302 }) 303 } 304 } 305 306 func TestMissingTable(t *testing.T) { 307 queries := []string{ 308 "select t.col from a", 309 } 310 311 for _, query := range queries { 312 t.Run(query, func(t *testing.T) { 313 parse, _ := sqlparser.Parse(query) 314 st, err := Analyze(parse.(sqlparser.SelectStatement), "", &FakeSI{}) 315 require.NoError(t, err) 316 
require.ErrorContains(t, st.NotUnshardedErr, "symbol t.col not found") 317 }) 318 } 319 } 320 321 func TestUnknownColumnMap2(t *testing.T) { 322 varchar := querypb.Type_VARCHAR 323 integer := querypb.Type_INT32 324 325 authoritativeTblA := vindexes.Table{ 326 Name: sqlparser.NewIdentifierCS("a"), 327 Columns: []vindexes.Column{{ 328 Name: sqlparser.NewIdentifierCI("col2"), 329 Type: varchar, 330 }}, 331 ColumnListAuthoritative: true, 332 } 333 authoritativeTblB := vindexes.Table{ 334 Name: sqlparser.NewIdentifierCS("b"), 335 Columns: []vindexes.Column{{ 336 Name: sqlparser.NewIdentifierCI("col"), 337 Type: varchar, 338 }}, 339 ColumnListAuthoritative: true, 340 } 341 nonAuthoritativeTblA := authoritativeTblA 342 nonAuthoritativeTblA.ColumnListAuthoritative = false 343 nonAuthoritativeTblB := authoritativeTblB 344 nonAuthoritativeTblB.ColumnListAuthoritative = false 345 authoritativeTblAWithConflict := vindexes.Table{ 346 Name: sqlparser.NewIdentifierCS("a"), 347 Columns: []vindexes.Column{{ 348 Name: sqlparser.NewIdentifierCI("col"), 349 Type: integer, 350 }}, 351 ColumnListAuthoritative: true, 352 } 353 authoritativeTblBWithInt := vindexes.Table{ 354 Name: sqlparser.NewIdentifierCS("b"), 355 Columns: []vindexes.Column{{ 356 Name: sqlparser.NewIdentifierCI("col"), 357 Type: integer, 358 }}, 359 ColumnListAuthoritative: true, 360 } 361 362 tests := []struct { 363 name string 364 schema map[string]*vindexes.Table 365 err bool 366 typ *querypb.Type 367 }{{ 368 name: "no info about tables", 369 schema: map[string]*vindexes.Table{"a": {}, "b": {}}, 370 err: true, 371 }, { 372 name: "non authoritative columns", 373 schema: map[string]*vindexes.Table{"a": &nonAuthoritativeTblA, "b": &nonAuthoritativeTblA}, 374 err: true, 375 }, { 376 name: "non authoritative columns - one authoritative and one not", 377 schema: map[string]*vindexes.Table{"a": &nonAuthoritativeTblA, "b": &authoritativeTblB}, 378 err: false, 379 typ: &varchar, 380 }, { 381 name: "non authoritative columns - 
one authoritative and one not", 382 schema: map[string]*vindexes.Table{"a": &authoritativeTblA, "b": &nonAuthoritativeTblB}, 383 err: false, 384 typ: &varchar, 385 }, { 386 name: "authoritative columns", 387 schema: map[string]*vindexes.Table{"a": &authoritativeTblA, "b": &authoritativeTblB}, 388 err: false, 389 typ: &varchar, 390 }, { 391 name: "authoritative columns", 392 schema: map[string]*vindexes.Table{"a": &authoritativeTblA, "b": &authoritativeTblBWithInt}, 393 err: false, 394 typ: &integer, 395 }, { 396 name: "authoritative columns with overlap", 397 schema: map[string]*vindexes.Table{"a": &authoritativeTblAWithConflict, "b": &authoritativeTblB}, 398 err: true, 399 }} 400 401 queries := []string{"select col from a, b", "select col from a as user, b as extra"} 402 for _, query := range queries { 403 t.Run(query, func(t *testing.T) { 404 parse, _ := sqlparser.Parse(query) 405 expr := extract(parse.(*sqlparser.Select), 0) 406 407 for _, test := range tests { 408 t.Run(test.name, func(t *testing.T) { 409 si := &FakeSI{Tables: test.schema} 410 tbl, err := Analyze(parse.(sqlparser.SelectStatement), "", si) 411 if test.err { 412 require.True(t, err != nil || tbl.NotSingleRouteErr != nil) 413 } else { 414 require.NoError(t, err) 415 require.NoError(t, tbl.NotSingleRouteErr) 416 typ := tbl.TypeFor(expr) 417 assert.Equal(t, test.typ, typ) 418 } 419 }) 420 } 421 }) 422 } 423 } 424 425 func TestUnknownPredicate(t *testing.T) { 426 query := "select 1 from a, b where col = 1" 427 authoritativeTblA := &vindexes.Table{ 428 Name: sqlparser.NewIdentifierCS("a"), 429 } 430 authoritativeTblB := &vindexes.Table{ 431 Name: sqlparser.NewIdentifierCS("b"), 432 } 433 434 parse, _ := sqlparser.Parse(query) 435 436 tests := []struct { 437 name string 438 schema map[string]*vindexes.Table 439 err bool 440 }{ 441 { 442 name: "no info about tables", 443 schema: map[string]*vindexes.Table{"a": authoritativeTblA, "b": authoritativeTblB}, 444 err: false, 445 }, 446 } 447 for _, test := 
range tests { 448 t.Run(test.name, func(t *testing.T) { 449 si := &FakeSI{Tables: test.schema} 450 _, err := Analyze(parse.(sqlparser.SelectStatement), "", si) 451 if test.err { 452 require.Error(t, err) 453 } else { 454 require.NoError(t, err) 455 } 456 }) 457 } 458 } 459 460 func TestScoping(t *testing.T) { 461 queries := []struct { 462 query string 463 errorMessage string 464 }{ 465 { 466 query: "select 1 from u1, u2 left join u3 on u1.a = u2.a", 467 errorMessage: "symbol u1.a not found", 468 }, 469 } 470 for _, query := range queries { 471 t.Run(query.query, func(t *testing.T) { 472 parse, err := sqlparser.Parse(query.query) 473 require.NoError(t, err) 474 st, err := Analyze(parse.(sqlparser.SelectStatement), "user", &FakeSI{ 475 Tables: map[string]*vindexes.Table{ 476 "t": {Name: sqlparser.NewIdentifierCS("t")}, 477 }, 478 }) 479 require.NoError(t, err) 480 require.EqualError(t, st.NotUnshardedErr, query.errorMessage) 481 }) 482 } 483 } 484 485 func TestScopeForSubqueries(t *testing.T) { 486 tcases := []struct { 487 sql string 488 deps TableSet 489 }{ 490 { 491 sql: `select t.col1, (select t.col2 from z as t) from x as t`, 492 deps: T2, 493 }, { 494 sql: `select t.col1, (select t.col2 from z) from x as t`, 495 deps: T1, 496 }, { 497 sql: `select t.col1, (select (select z.col2 from y) from z) from x as t`, 498 deps: T2, 499 }, { 500 sql: `select t.col1, (select (select y.col2 from y) from z) from x as t`, 501 deps: None, 502 }, { 503 sql: `select t.col1, (select (select (select (select w.col2 from w) from x) from y) from z) from x as t`, 504 deps: None, 505 }, { 506 sql: `select t.col1, (select id from t) from x as t`, 507 deps: T2, 508 }, 509 } 510 for _, tc := range tcases { 511 t.Run(tc.sql, func(t *testing.T) { 512 stmt, semTable := parseAndAnalyze(t, tc.sql, "d") 513 sel, _ := stmt.(*sqlparser.Select) 514 515 // extract the first expression from the subquery (which should be the second expression in the outer query) 516 sel2 := 
sel.SelectExprs[1].(*sqlparser.AliasedExpr).Expr.(*sqlparser.Subquery).Select.(*sqlparser.Select) 517 exp := extract(sel2, 0) 518 s1 := semTable.RecursiveDeps(exp) 519 require.NoError(t, semTable.NotSingleRouteErr) 520 // if scoping works as expected, we should be able to see the inner table being used by the inner expression 521 assert.Equal(t, tc.deps, s1) 522 }) 523 } 524 } 525 526 func TestSubqueriesMappingWhereClause(t *testing.T) { 527 tcs := []struct { 528 sql string 529 opCode engine.PulloutOpcode 530 otherSideName string 531 }{ 532 { 533 sql: "select id from t1 where id in (select uid from t2)", 534 opCode: engine.PulloutIn, 535 otherSideName: "id", 536 }, 537 { 538 sql: "select id from t1 where id not in (select uid from t2)", 539 opCode: engine.PulloutNotIn, 540 otherSideName: "id", 541 }, 542 { 543 sql: "select id from t where col1 = (select uid from t2 order by uid desc limit 1)", 544 opCode: engine.PulloutValue, 545 otherSideName: "col1", 546 }, 547 { 548 sql: "select id from t where exists (select uid from t2 where uid = 42)", 549 opCode: engine.PulloutExists, 550 otherSideName: "", 551 }, 552 { 553 sql: "select id from t where col1 >= (select uid from t2 where uid = 42)", 554 opCode: engine.PulloutValue, 555 otherSideName: "col1", 556 }, 557 } 558 559 for i, tc := range tcs { 560 t.Run(fmt.Sprintf("%d_%s", i+1, tc.sql), func(t *testing.T) { 561 stmt, semTable := parseAndAnalyze(t, tc.sql, "d") 562 sel, _ := stmt.(*sqlparser.Select) 563 564 var subq *sqlparser.Subquery 565 switch whereExpr := sel.Where.Expr.(type) { 566 case *sqlparser.ComparisonExpr: 567 subq = whereExpr.Right.(*sqlparser.Subquery) 568 case *sqlparser.ExistsExpr: 569 subq = whereExpr.Subquery 570 } 571 572 extractedSubq := semTable.SubqueryRef[subq] 573 assert.True(t, sqlparser.Equals.Expr(extractedSubq.Subquery, subq)) 574 assert.True(t, sqlparser.Equals.Expr(extractedSubq.Original, sel.Where.Expr)) 575 assert.EqualValues(t, tc.opCode, extractedSubq.OpCode) 576 if tc.otherSideName 
== "" { 577 assert.Nil(t, extractedSubq.OtherSide) 578 } else { 579 assert.True(t, sqlparser.Equals.Expr(extractedSubq.OtherSide, sqlparser.NewColName(tc.otherSideName))) 580 } 581 }) 582 } 583 } 584 585 func TestSubqueriesMappingSelectExprs(t *testing.T) { 586 tcs := []struct { 587 sql string 588 selExprIdx int 589 }{ 590 { 591 sql: "select (select id from t1)", 592 selExprIdx: 0, 593 }, 594 { 595 sql: "select id, (select id from t1) from t1", 596 selExprIdx: 1, 597 }, 598 } 599 600 for i, tc := range tcs { 601 t.Run(fmt.Sprintf("%d_%s", i+1, tc.sql), func(t *testing.T) { 602 stmt, semTable := parseAndAnalyze(t, tc.sql, "d") 603 sel, _ := stmt.(*sqlparser.Select) 604 605 subq := sel.SelectExprs[tc.selExprIdx].(*sqlparser.AliasedExpr).Expr.(*sqlparser.Subquery) 606 extractedSubq := semTable.SubqueryRef[subq] 607 assert.True(t, sqlparser.Equals.Expr(extractedSubq.Subquery, subq)) 608 assert.True(t, sqlparser.Equals.Expr(extractedSubq.Original, subq)) 609 assert.EqualValues(t, engine.PulloutValue, extractedSubq.OpCode) 610 }) 611 } 612 } 613 614 func TestSubqueryOrderByBinding(t *testing.T) { 615 queries := []struct { 616 query string 617 expected TableSet 618 }{{ 619 query: "select * from user u where exists (select * from user order by col)", 620 expected: T2, 621 }, { 622 query: "select * from user u where exists (select * from user order by user.col)", 623 expected: T2, 624 }, { 625 query: "select * from user u where exists (select * from user order by u.col)", 626 expected: T1, 627 }, { 628 query: "select * from dbName.user as u where exists (select * from dbName.user order by u.col)", 629 expected: T1, 630 }, { 631 query: "select * from dbName.user where exists (select * from otherDb.user order by dbName.user.col)", 632 expected: T1, 633 }, { 634 query: "select id from dbName.t1 where exists (select * from dbName.t2 order by dbName.t1.id)", 635 expected: T1, 636 }} 637 638 for _, tc := range queries { 639 t.Run(tc.query, func(t *testing.T) { 640 ast, err := 
sqlparser.Parse(tc.query) 641 require.NoError(t, err) 642 643 sel := ast.(*sqlparser.Select) 644 st, err := Analyze(sel, "dbName", fakeSchemaInfo()) 645 require.NoError(t, err) 646 exists := sel.Where.Expr.(*sqlparser.ExistsExpr) 647 expr := exists.Subquery.Select.(*sqlparser.Select).OrderBy[0].Expr 648 require.Equal(t, tc.expected, st.DirectDeps(expr)) 649 require.Equal(t, tc.expected, st.RecursiveDeps(expr)) 650 }) 651 } 652 } 653 654 func TestOrderByBindingTable(t *testing.T) { 655 tcases := []struct { 656 sql string 657 deps TableSet 658 }{{ 659 "select col from tabl order by col", 660 T1, 661 }, { 662 "select tabl.col from d.tabl order by col", 663 T1, 664 }, { 665 "select d.tabl.col from d.tabl order by col", 666 T1, 667 }, { 668 "select col from tabl order by tabl.col", 669 T1, 670 }, { 671 "select col from tabl order by d.tabl.col", 672 T1, 673 }, { 674 "select col from tabl order by 1", 675 T1, 676 }, { 677 "select col as c from tabl order by c", 678 T1, 679 }, { 680 "select 1 as c from tabl order by c", 681 T0, 682 }, { 683 "select name, name from t1, t2 order by name", 684 T2, 685 }, { 686 "(select id from t1) union (select uid from t2) order by id", 687 MergeTableSets(T1, T2), 688 }, { 689 "select id from t1 union (select uid from t2) order by 1", 690 MergeTableSets(T1, T2), 691 }, { 692 "select id from t1 union select uid from t2 union (select name from t) order by 1", 693 MergeTableSets(T1, T2, T3), 694 }, { 695 "select a.id from t1 as a union (select uid from t2) order by 1", 696 MergeTableSets(T1, T2), 697 }, { 698 "select b.id as a from t1 as b union (select uid as c from t2) order by 1", 699 MergeTableSets(T1, T2), 700 }, { 701 "select a.id from t1 as a union (select uid from t2, t union (select name from t) order by 1) order by 1", 702 MergeTableSets(T1, T2, T4), 703 }, { 704 "select a.id from t1 as a union (select uid from t2, t union (select name from t) order by 1) order by id", 705 MergeTableSets(T1, T2, T4), 706 }} 707 for _, tc := range 
tcases { 708 t.Run(tc.sql, func(t *testing.T) { 709 stmt, semTable := parseAndAnalyze(t, tc.sql, "d") 710 711 var order sqlparser.Expr 712 switch stmt := stmt.(type) { 713 case *sqlparser.Select: 714 order = stmt.OrderBy[0].Expr 715 case *sqlparser.Union: 716 order = stmt.OrderBy[0].Expr 717 default: 718 t.Fail() 719 } 720 d := semTable.RecursiveDeps(order) 721 require.Equal(t, tc.deps, d, tc.sql) 722 }) 723 } 724 } 725 726 func TestGroupByBinding(t *testing.T) { 727 tcases := []struct { 728 sql string 729 deps TableSet 730 }{{ 731 "select col from tabl group by col", 732 T1, 733 }, { 734 "select col from tabl group by tabl.col", 735 T1, 736 }, { 737 "select col from tabl group by d.tabl.col", 738 T1, 739 }, { 740 "select tabl.col as x from tabl group by x", 741 T1, 742 }, { 743 "select tabl.col as x from tabl group by col", 744 T1, 745 }, { 746 "select d.tabl.col as x from tabl group by x", 747 T1, 748 }, { 749 "select d.tabl.col as x from tabl group by col", 750 T1, 751 }, { 752 "select col from tabl group by 1", 753 T1, 754 }, { 755 "select col as c from tabl group by c", 756 T1, 757 }, { 758 "select 1 as c from tabl group by c", 759 T0, 760 }, { 761 "select t1.id from t1, t2 group by id", 762 T1, 763 }, { 764 "select id from t, t1 group by id", 765 T2, 766 }, { 767 "select id from t, t1 group by id", 768 T2, 769 }, { 770 "select a.id from t as a, t1 group by id", 771 T1, 772 }, { 773 "select a.id from t, t1 as a group by id", 774 T2, 775 }} 776 for _, tc := range tcases { 777 t.Run(tc.sql, func(t *testing.T) { 778 stmt, semTable := parseAndAnalyze(t, tc.sql, "d") 779 sel, _ := stmt.(*sqlparser.Select) 780 grp := sel.GroupBy[0] 781 d := semTable.RecursiveDeps(grp) 782 require.Equal(t, tc.deps, d, tc.sql) 783 }) 784 } 785 } 786 787 func TestHavingBinding(t *testing.T) { 788 tcases := []struct { 789 sql string 790 deps TableSet 791 }{{ 792 "select col from tabl having col = 1", 793 T1, 794 }, { 795 "select col from tabl having tabl.col = 1", 796 T1, 797 }, { 798 
"select col from tabl having d.tabl.col = 1", 799 T1, 800 }, { 801 "select tabl.col as x from tabl having x = 1", 802 T1, 803 }, { 804 "select tabl.col as x from tabl having col", 805 T1, 806 }, { 807 "select col from tabl having 1 = 1", 808 T0, 809 }, { 810 "select col as c from tabl having c = 1", 811 T1, 812 }, { 813 "select 1 as c from tabl having c = 1", 814 T0, 815 }, { 816 "select t1.id from t1, t2 having id = 1", 817 T1, 818 }, { 819 "select t.id from t, t1 having id = 1", 820 T1, 821 }, { 822 "select t.id, count(*) as a from t, t1 group by t.id having a = 1", 823 MergeTableSets(T1, T2), 824 }, { 825 "select t.id, sum(t2.name) as a from t, t2 group by t.id having a = 1", 826 T2, 827 }, { 828 sql: "select u2.a, u1.a from u1, u2 having u2.a = 2", 829 deps: T2, 830 }} 831 for _, tc := range tcases { 832 t.Run(tc.sql, func(t *testing.T) { 833 stmt, semTable := parseAndAnalyze(t, tc.sql, "d") 834 sel, _ := stmt.(*sqlparser.Select) 835 hvng := sel.Having.Expr 836 d := semTable.RecursiveDeps(hvng) 837 require.Equal(t, tc.deps, d, tc.sql) 838 }) 839 } 840 } 841 842 func TestUnionCheckFirstAndLastSelectsDeps(t *testing.T) { 843 query := "select col1 from tabl1 union select col2 from tabl2" 844 845 stmt, semTable := parseAndAnalyze(t, query, "") 846 union, _ := stmt.(*sqlparser.Union) 847 sel1 := union.Left.(*sqlparser.Select) 848 sel2 := union.Right.(*sqlparser.Select) 849 850 t1 := sel1.From[0].(*sqlparser.AliasedTableExpr) 851 t2 := sel2.From[0].(*sqlparser.AliasedTableExpr) 852 ts1 := semTable.TableSetFor(t1) 853 ts2 := semTable.TableSetFor(t2) 854 assert.Equal(t, SingleTableSet(0), ts1) 855 assert.Equal(t, SingleTableSet(1), ts2) 856 857 d1 := semTable.RecursiveDeps(extract(sel1, 0)) 858 d2 := semTable.RecursiveDeps(extract(sel2, 0)) 859 assert.Equal(t, T1, d1) 860 assert.Equal(t, T2, d2) 861 } 862 863 func TestUnionOrderByRewrite(t *testing.T) { 864 query := "select tabl1.id from tabl1 union select 1 order by 1" 865 866 stmt, _ := parseAndAnalyze(t, query, "") 
867 assert.Equal(t, "select tabl1.id from tabl1 union select 1 from dual order by id asc", sqlparser.String(stmt)) 868 } 869 870 func TestInvalidQueries(t *testing.T) { 871 tcases := []struct { 872 sql string 873 err string 874 shardedErr string 875 }{{ 876 sql: "select t1.id, t1.col1 from t1 union select t2.uid from t2", 877 err: "The used SELECT statements have a different number of columns", 878 }, { 879 sql: "select t1.id from t1 union select t2.uid, t2.price from t2", 880 err: "The used SELECT statements have a different number of columns", 881 }, { 882 sql: "select t1.id from t1 union select t2.uid, t2.price from t2", 883 err: "The used SELECT statements have a different number of columns", 884 }, { 885 sql: "(select 1,2 union select 3,4) union (select 5,6 union select 7)", 886 err: "The used SELECT statements have a different number of columns", 887 }, { 888 sql: "select id from a union select 3 order by a.id", 889 err: "Table a from one of the SELECTs cannot be used in global ORDER clause", 890 }, { 891 sql: "select a.id, b.id from a, b union select 1, 2 order by id", 892 err: "Column 'id' in field list is ambiguous", 893 }, { 894 sql: "select sql_calc_found_rows id from a union select 1 limit 109", 895 err: "VT12001: unsupported: SQL_CALC_FOUND_ROWS not supported with union", 896 }, { 897 sql: "select * from (select sql_calc_found_rows id from a) as t", 898 err: "Incorrect usage/placement of 'SQL_CALC_FOUND_ROWS'", 899 }, { 900 sql: "select (select sql_calc_found_rows id from a) as t", 901 err: "Incorrect usage/placement of 'SQL_CALC_FOUND_ROWS'", 902 }, { 903 sql: "select id from t1 natural join t2", 904 err: "VT12001: unsupported: natural join", 905 }, { 906 sql: "select * from music where user_id IN (select sql_calc_found_rows * from music limit 10)", 907 err: "Incorrect usage/placement of 'SQL_CALC_FOUND_ROWS'", 908 }, { 909 sql: "select is_free_lock('xyz') from user", 910 err: "is_free_lock('xyz') allowed only with dual", 911 }, { 912 sql: "SELECT * 
FROM JSON_TABLE('[ {\"c1\": null} ]','$[*]' COLUMNS( c1 INT PATH '$.c1' ERROR ON ERROR )) as jt", 913 err: "VT12001: unsupported: json_table expressions", 914 }, { 915 sql: "select does_not_exist from t1", 916 shardedErr: "symbol does_not_exist not found", 917 }, { 918 sql: "select t1.does_not_exist from t1, t2", 919 shardedErr: "symbol t1.does_not_exist not found", 920 }} 921 for _, tc := range tcases { 922 t.Run(tc.sql, func(t *testing.T) { 923 parse, err := sqlparser.Parse(tc.sql) 924 require.NoError(t, err) 925 926 st, err := Analyze(parse.(sqlparser.SelectStatement), "dbName", fakeSchemaInfo()) 927 if tc.err != "" { 928 require.EqualError(t, err, tc.err) 929 } else { 930 require.NoError(t, err, tc.err) 931 require.EqualError(t, st.NotUnshardedErr, tc.shardedErr) 932 } 933 }) 934 } 935 } 936 937 func TestUnionWithOrderBy(t *testing.T) { 938 query := "select col1 from tabl1 union (select col2 from tabl2) order by 1" 939 940 stmt, semTable := parseAndAnalyze(t, query, "") 941 union, _ := stmt.(*sqlparser.Union) 942 sel1 := sqlparser.GetFirstSelect(union) 943 sel2 := sqlparser.GetFirstSelect(union.Right) 944 945 t1 := sel1.From[0].(*sqlparser.AliasedTableExpr) 946 t2 := sel2.From[0].(*sqlparser.AliasedTableExpr) 947 ts1 := semTable.TableSetFor(t1) 948 ts2 := semTable.TableSetFor(t2) 949 assert.Equal(t, SingleTableSet(0), ts1) 950 assert.Equal(t, SingleTableSet(1), ts2) 951 952 d1 := semTable.RecursiveDeps(extract(sel1, 0)) 953 d2 := semTable.RecursiveDeps(extract(sel2, 0)) 954 assert.Equal(t, T1, d1) 955 assert.Equal(t, T2, d2) 956 } 957 958 func TestScopingWDerivedTables(t *testing.T) { 959 queries := []struct { 960 query string 961 errorMessage string 962 recursiveExpectation TableSet 963 expectation TableSet 964 }{ 965 { 966 query: "select id from (select x as id from user) as t", 967 recursiveExpectation: T1, 968 expectation: T2, 969 }, { 970 query: "select id from (select foo as id from user) as t", 971 recursiveExpectation: T1, 972 expectation: T2, 973 }, { 
974 query: "select id from (select foo as id from (select x as foo from user) as c) as t", 975 recursiveExpectation: T1, 976 expectation: T3, 977 }, { 978 query: "select t.id from (select foo as id from user) as t", 979 recursiveExpectation: T1, 980 expectation: T2, 981 }, { 982 query: "select t.id2 from (select foo as id from user) as t", 983 errorMessage: "symbol t.id2 not found", 984 }, { 985 query: "select id from (select 42 as id) as t", 986 recursiveExpectation: T0, 987 expectation: T2, 988 }, { 989 query: "select t.id from (select 42 as id) as t", 990 recursiveExpectation: T0, 991 expectation: T2, 992 }, { 993 query: "select ks.t.id from (select 42 as id) as t", 994 errorMessage: "symbol ks.t.id not found", 995 }, { 996 query: "select * from (select id, id from user) as t", 997 errorMessage: "Duplicate column name 'id'", 998 }, { 999 query: "select t.baz = 1 from (select id as baz from user) as t", 1000 expectation: T2, 1001 recursiveExpectation: T1, 1002 }, { 1003 query: "select t.id from (select * from user, music) as t", 1004 expectation: T3, 1005 recursiveExpectation: MergeTableSets(T1, T2), 1006 }, { 1007 query: "select t.id from (select * from user, music) as t order by t.id", 1008 expectation: T3, 1009 recursiveExpectation: MergeTableSets(T1, T2), 1010 }, { 1011 query: "select t.id from (select * from user) as t join user as u on t.id = u.id", 1012 expectation: T2, 1013 recursiveExpectation: T1, 1014 }, { 1015 query: "select t.col1 from t3 ua join (select t1.id, t1.col1 from t1 join t2) as t", 1016 expectation: T4, 1017 recursiveExpectation: T2, 1018 }, { 1019 query: "select uu.test from (select id from t1) uu", 1020 errorMessage: "symbol uu.test not found", 1021 }, { 1022 query: "select uu.id from (select id as col from t1) uu", 1023 errorMessage: "symbol uu.id not found", 1024 }, { 1025 query: "select uu.id from (select id as col from t1) uu", 1026 errorMessage: "symbol uu.id not found", 1027 }, { 1028 query: "select uu.id from (select id from t1) 
as uu where exists (select * from t2 as uu where uu.id = uu.uid)",
		expectation:          T2,
		recursiveExpectation: T1,
	}, {
		query:                "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))",
		expectation:          T0,
		recursiveExpectation: T0,
	}}
	for _, query := range queries {
		t.Run(query.query, func(t *testing.T) {
			parse, err := sqlparser.Parse(query.query)
			require.NoError(t, err)
			st, err := Analyze(parse.(sqlparser.SelectStatement), "user", &FakeSI{
				Tables: map[string]*vindexes.Table{
					"t": {Name: sqlparser.NewIdentifierCS("t")},
				},
			})

			switch {
			case query.errorMessage != "" && err != nil:
				require.EqualError(t, err, query.errorMessage)
			case query.errorMessage != "":
				require.EqualError(t, st.NotUnshardedErr, query.errorMessage)
			default:
				require.NoError(t, err)
				sel := parse.(*sqlparser.Select)
				assert.Equal(t, query.recursiveExpectation, st.RecursiveDeps(extract(sel, 0)), "RecursiveDeps")
				assert.Equal(t, query.expectation, st.DirectDeps(extract(sel, 0)), "DirectDeps")
			}
		})
	}
}

// TestDerivedTablesOrderClause analyzes queries that select from a derived
// table and checks the dependencies of the first ORDER BY expression: every
// case expects it to depend directly on T2 and recursively on T1.
func TestDerivedTablesOrderClause(t *testing.T) {
	queries := []struct {
		query                string
		recursiveExpectation TableSet
		expectation          TableSet
	}{{
		query:                "select 1 from (select id from user) as t order by id",
		recursiveExpectation: T1,
		expectation:          T2,
	}, {
		query:                "select id from (select id from user) as t order by id",
		recursiveExpectation: T1,
		expectation:          T2,
	}, {
		query:                "select id from (select id from user) as t order by t.id",
		recursiveExpectation: T1,
		expectation:          T2,
	}, {
		query:                "select id as foo from (select id from user) as t order by foo",
		recursiveExpectation: T1,
		expectation:          T2,
	}, {
		query:                "select bar from (select id as bar from user) as t order by bar",
		recursiveExpectation: T1,
		expectation:          T2,
	}, {
		query:                "select bar as foo from (select id as bar from user) as t order by bar",
		recursiveExpectation: T1,
		expectation:          T2,
	}, {
		query:                "select bar as foo from (select id as bar from user) as t order by foo",
		recursiveExpectation: T1,
		expectation:          T2,
	}, {
		query:                "select bar as foo from (select id as bar, oo from user) as t order by oo",
		recursiveExpectation: T1,
		expectation:          T2,
	}, {
		query:                "select bar as foo from (select id, oo from user) as t(bar,oo) order by bar",
		recursiveExpectation: T1,
		expectation:          T2,
	}}
	si := &FakeSI{Tables: map[string]*vindexes.Table{"t": {Name: sqlparser.NewIdentifierCS("t")}}}
	for _, query := range queries {
		t.Run(query.query, func(t *testing.T) {
			parse, err := sqlparser.Parse(query.query)
			require.NoError(t, err)

			st, err := Analyze(parse.(sqlparser.SelectStatement), "user", si)
			require.NoError(t, err)

			orderExpr := parse.(*sqlparser.Select).OrderBy[0].Expr
			assert.Equal(t, query.recursiveExpectation, st.RecursiveDeps(orderExpr), "RecursiveDeps")
			assert.Equal(t, query.expectation, st.DirectDeps(orderExpr), "DirectDeps")
		})
	}
}

// TestScopingWComplexDerivedTables checks the recursive dependencies of both
// operands of the innermost `uu.user_id = uu.id` comparison. The alias uu is
// declared at more than one nesting level of the EXISTS subqueries.
func TestScopingWComplexDerivedTables(t *testing.T) {
	queries := []struct {
		query            string
		errorMessage     string
		rightExpectation TableSet
		leftExpectation  TableSet
	}{
		{
			query:            "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))",
			rightExpectation: T1,
			leftExpectation:  T1,
		},
		{
			query:            "select 1 from user.user uu where exists (select 1 from user.user as uu where exists (select 1 from (select 1 from user.t1) uu where uu.user_id = uu.id))",
			rightExpectation: T2,
			leftExpectation:  T2,
		},
	}
	for _, query := range queries {
		t.Run(query.query, func(t *testing.T) {
			parse, err := sqlparser.Parse(query.query)
			require.NoError(t, err)
			st, err := Analyze(parse.(sqlparser.SelectStatement), "user", &FakeSI{
				Tables: map[string]*vindexes.Table{
					"t": {Name: sqlparser.NewIdentifierCS("t")},
				},
			})
			if query.errorMessage != "" {
				require.EqualError(t, err, query.errorMessage)
				return
			}
			require.NoError(t, err)

			// Walk outer SELECT -> first EXISTS -> second EXISTS -> its WHERE comparison.
			outer := parse.(*sqlparser.Select)
			middle := outer.Where.Expr.(*sqlparser.ExistsExpr).Subquery.Select.(*sqlparser.Select)
			inner := middle.Where.Expr.(*sqlparser.ExistsExpr).Subquery.Select.(*sqlparser.Select)
			comparisonExpr := inner.Where.Expr.(*sqlparser.ComparisonExpr)

			assert.Equal(t, query.leftExpectation, st.RecursiveDeps(comparisonExpr.Left), "Left RecursiveDeps")
			assert.Equal(t, query.rightExpectation, st.RecursiveDeps(comparisonExpr.Right), "Right RecursiveDeps")
		})
	}
}

// TestScopingWVindexTables checks dependency resolution when the hash vindex
// "user_index" is selected from like a table. It covers the vindex on its own
// and the vindex joined with the regular table t.
func TestScopingWVindexTables(t *testing.T) {
	queries := []struct {
		query                string
		errorMessage         string
		recursiveExpectation TableSet
		expectation          TableSet
	}{
		{
			query:                "select id from user_index where id = 1",
			recursiveExpectation: T1,
			expectation:          T1,
		}, {
			query:                "select u.id + t.id from t as t join user_index as u where u.id = 1 and u.id = t.id",
			recursiveExpectation: MergeTableSets(T1, T2),
			expectation:          MergeTableSets(T1, T2),
		},
	}

	// The vindex is identical for every case, so build it once. Its error used
	// to be discarded; check it so a construction failure is reported here and
	// not as a confusing nil vindex inside Analyze.
	hash, err := vindexes.NewHash("user_index", nil)
	require.NoError(t, err)

	for _, query := range queries {
		t.Run(query.query, func(t *testing.T) {
			parse, err := sqlparser.Parse(query.query)
			require.NoError(t, err)
			st, err := Analyze(parse.(sqlparser.SelectStatement), "user", &FakeSI{
				Tables: map[string]*vindexes.Table{
					"t": {Name: sqlparser.NewIdentifierCS("t")},
				},
				VindexTables: map[string]vindexes.Vindex{
					"user_index": hash,
				},
			})
			if query.errorMessage != "" {
				require.EqualError(t, err, query.errorMessage)
				return
			}
			require.NoError(t, err)
			sel := parse.(*sqlparser.Select)
			assert.Equal(t, query.recursiveExpectation, st.RecursiveDeps(extract(sel, 0)))
			assert.Equal(t, query.expectation, st.DirectDeps(extract(sel, 0)))
		})
	}
}

// runAnalyzeBenchmark is the shared body of the BenchmarkAnalyze* functions.
// On every benchmark iteration it parses each query and runs Analyze on it,
// with "d" as the current database and the schema from fakeSchemaInfo.
//
// The schema is built once, outside the timed region.
// NOTE(review): this assumes Analyze only reads the FakeSI; confirm.
// Parsing stays inside the timed loop, as in the original benchmarks, so each
// Analyze call receives a freshly parsed statement (presumably because Analyze
// may rewrite its input).
func runAnalyzeBenchmark(b *testing.B, queries []string) {
	b.Helper()
	si := fakeSchemaInfo()
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, query := range queries {
			parse, err := sqlparser.Parse(query)
			require.NoError(b, err)

			// Analysis errors are deliberately ignored. Some queries here
			// (e.g. the "t.id2" one) are expected to fail semantic analysis,
			// and the benchmark measures the work either way.
			_, _ = Analyze(parse.(sqlparser.SelectStatement), "d", si)
		}
	}
}

// BenchmarkAnalyzeMultipleDifferentQueries measures Analyze over a mix of
// simple, aggregate, subquery, union, group by, having and derived-table queries.
func BenchmarkAnalyzeMultipleDifferentQueries(b *testing.B) {
	runAnalyzeBenchmark(b, []string{
		"select col from tabl",
		"select t.col from t, s",
		"select max(tabl.col1 + tabl.col2) from d.X as tabl",
		"select max(X.col + S.col) from t as X, s as S",
		"select case t.col when s.col then r.col else u.col end from t, s, r, w, u",
		"select t.col1, (select t.col2 from z as t) from x as t",
		"select * from user u where exists (select * from user order by col)",
		"select id from dbName.t1 where exists (select * from dbName.t2 order by dbName.t1.id)",
		"select d.tabl.col from d.tabl order by col",
		"select a.id from t1 as a union (select uid from t2, t union (select name from t) order by 1) order by 1",
		"select a.id from t, t1 as a group by id",
		"select tabl.col as x from tabl having x = 1",
		"select id from (select foo as id from (select x as foo from user) as c) as t",
	})
}

// BenchmarkAnalyzeUnionQueries measures Analyze on UNION queries.
func BenchmarkAnalyzeUnionQueries(b *testing.B) {
	runAnalyzeBenchmark(b, []string{
		"select id from t1 union select uid from t2",
		"select col1 from tabl1 union (select col2 from tabl2)",
		"select t1.id, t1.col1 from t1 union select t2.uid from t2",
		"select a.id from t1 as a union (select uid from t2, t union (select name from t) order by 1) order by 1",
		"select b.id as a from t1 as b union (select uid as c from t2) order by 1",
		"select a.id from t1 as a union (select uid from t2) order by 1",
		"select id from t1 union select uid from t2 union (select name from t)",
		"select id from t1 union (select uid from t2) order by 1",
		"(select id from t1) union (select uid from t2) order by id",
	})
}

// BenchmarkAnalyzeSubQueries measures Analyze on queries with correlated and
// nested subqueries.
func BenchmarkAnalyzeSubQueries(b *testing.B) {
	runAnalyzeBenchmark(b, []string{
		"select * from user u where exists (select * from user order by col)",
		"select * from user u where exists (select * from user order by user.col)",
		"select * from user u where exists (select * from user order by u.col)",
		"select * from dbName.user as u where exists (select * from dbName.user order by u.col)",
		"select * from dbName.user where exists (select * from otherDb.user order by dbName.user.col)",
		"select id from dbName.t1 where exists (select * from dbName.t2 order by dbName.t1.id)",
		"select t.col1, (select t.col2 from z as t) from x as t",
		"select t.col1, (select t.col2 from z) from x as t",
		"select t.col1, (select (select z.col2 from y) from z) from x as t",
		"select t.col1, (select (select y.col2 from y) from z) from x as t",
		"select t.col1, (select (select (select (select w.col2 from w) from x) from y) from z) from x as t",
		"select t.col1, (select id from t) from x as t",
	})
}

// BenchmarkAnalyzeDerivedTableQueries measures Analyze on queries that select
// from derived tables.
func BenchmarkAnalyzeDerivedTableQueries(b *testing.B) {
	runAnalyzeBenchmark(b, []string{
		"select id from (select x as id from user) as t",
		"select id from (select foo as id from user) as t",
		"select id from (select foo as id from (select x as foo from user) as c) as t",
		"select t.id from (select foo as id from user) as t",
		"select t.id2 from (select foo as id from user) as t",
		"select id from (select 42 as id) as t",
		"select t.id from (select 42 as id) as t",
		"select ks.t.id from (select 42 as id) as t",
		"select * from (select id, id from user) as t",
		"select t.baz = 1 from (select id as baz from user) as t",
		"select t.id from (select * from user, music) as t",
		"select t.id from (select * from user, music) as t order by t.id",
		"select t.id from (select * from user) as t join user as u on t.id = u.id",
		"select t.col1 from t3 ua join (select t1.id, t1.col1 from t1 join t2) as t",
		"select uu.id from (select id from t1) as uu where exists (select * from t2 as uu where uu.id = uu.uid)",
		"select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))",
	})
}

// BenchmarkAnalyzeHavingQueries measures Analyze on queries with a HAVING clause.
func BenchmarkAnalyzeHavingQueries(b *testing.B) {
	runAnalyzeBenchmark(b, []string{
		"select col from tabl having col = 1",
		"select col from tabl having tabl.col = 1",
		"select col from tabl having d.tabl.col = 1",
		"select tabl.col as x from tabl having x = 1",
		"select tabl.col as x from tabl having col",
		"select col from tabl having 1 = 1",
		"select col as c from tabl having c = 1",
		"select 1 as c from tabl having c = 1",
		"select t1.id from t1, t2 having id = 1",
		"select t.id from t, t1 having id = 1",
		"select t.id, count(*) as a from t, t1 group by t.id having a = 1",
		"select u2.a, u1.a from u1, u2 having u2.a = 2",
	})
}

// BenchmarkAnalyzeGroupByQueries measures Analyze on queries with a GROUP BY clause.
func BenchmarkAnalyzeGroupByQueries(b *testing.B) {
	runAnalyzeBenchmark(b, []string{
		"select col from tabl group by col",
		"select col from tabl group by tabl.col",
		"select col from tabl group by d.tabl.col",
		"select tabl.col as x from tabl group by x",
		"select tabl.col as x from tabl group by col",
		"select d.tabl.col as x from tabl group by x",
		"select d.tabl.col as x from tabl group by col",
		"select col from tabl group by 1",
		"select col as c from tabl group by c",
		"select 1 as c from tabl group by c",
		"select t1.id from t1, t2 group by id",
		"select id from t, t1 group by id",
		"select a.id from t as a, t1 group by id",
		"select a.id from t, t1 as a group by id",
	})
}

// BenchmarkAnalyzeOrderByQueries measures Analyze on queries with an ORDER BY clause.
func BenchmarkAnalyzeOrderByQueries(b *testing.B) {
	runAnalyzeBenchmark(b, []string{
		"select col from tabl order by col",
		"select tabl.col from d.tabl order by col",
		"select d.tabl.col from d.tabl order by col",
		"select col from tabl order by tabl.col",
		"select col from tabl order by d.tabl.col",
		"select col from tabl order by 1",
		"select col as c from tabl order by c",
		"select 1 as c from tabl order by c",
		"select name, name from t1, t2 order by name",
	})
}

// parseAndAnalyze parses query and runs Analyze on it with dbName as the
// current database, against the schema returned by fakeSchemaInfo. It fails
// the test immediately if either step returns an error. On success it returns
// the parsed statement and the resulting SemTable.
func parseAndAnalyze(t *testing.T, query, dbName string) (sqlparser.Statement, *SemTable) {
	t.Helper()
	parse, err := sqlparser.Parse(query)
	require.NoError(t, err)

	semTable, err := Analyze(parse, dbName, fakeSchemaInfo())
	require.NoError(t, err)
	return parse, semTable
}

// TestSingleUnshardedKeyspace checks SemTable.SingleUnshardedKeyspace. It
// expects a keyspace and table list only when every table in the query
// resolves to the same unsharded keyspace, and nil/nil otherwise.
func TestSingleUnshardedKeyspace(t *testing.T) {
	tests := []struct {
		query     string
		unsharded *vindexes.Keyspace
		tables    []*vindexes.Table
	}{
		{
			query:     "select 1 from t, t1",
			unsharded: nil, // both tables are unsharded, but from different keyspaces
			tables:    nil,
		}, {
			query:     "select 1 from t2",
			unsharded: nil,
			tables:    nil,
		}, {
			query:     "select 1 from t, t2",
			unsharded: nil,
			tables:    nil,
		}, {
			query:     "select 1 from t as A, t as B",
			unsharded: ks1,
			tables: []*vindexes.Table{
				{Keyspace: ks1, Name: sqlparser.NewIdentifierCS("t")},
				{Keyspace: ks1, Name: sqlparser.NewIdentifierCS("t")},
			},
		},
	}

	for _, test := range tests {
		t.Run(test.query, func(t *testing.T) {
			_, semTable := parseAndAnalyze(t, test.query, "d")
			queryIsUnsharded, tables := semTable.SingleUnshardedKeyspace()
			assert.Equal(t, test.unsharded, queryIsUnsharded)
			assert.Equal(t, test.tables, tables)
		})
	}
}

// TestScopingSubQueryJoinClause tests scoping for a subquery whose join
// condition also references u3, a table of the outer query. The ON condition's
// direct dependencies must therefore cover all three tables (u1, u2 and u3).
func TestScopingSubQueryJoinClause(t *testing.T) {
	query := "select (select 1 from u1 join u2 on u1.id = u2.id and u2.id = u3.id) x from u3"

	parse, err := sqlparser.Parse(query)
	require.NoError(t, err)

	st, err := Analyze(parse, "user", &FakeSI{
		Tables: map[string]*vindexes.Table{
			"t": {Name: sqlparser.NewIdentifierCS("t")},
		},
	})
	require.NoError(t, err)
	require.NoError(t, st.NotUnshardedErr)

	// Walk outer SELECT -> subquery in the first select expression -> its join -> ON condition.
	subquery := parse.(*sqlparser.Select).SelectExprs[0].(*sqlparser.AliasedExpr).Expr.(*sqlparser.Subquery)
	join := subquery.Select.(*sqlparser.Select).From[0].(*sqlparser.JoinTableExpr)
	tb := st.DirectDeps(join.Condition.On)
	require.Equal(t, 3, tb.NumberOfTables())
}

// Keyspaces referenced by fakeSchemaInfo and the tests: ks1 and ks2 are
// unsharded, ks3 is sharded.
var (
	ks1 = &vindexes.Keyspace{
		Name:    "ks1",
		Sharded: false,
	}
	ks2 = &vindexes.Keyspace{
		Name:    "ks2",
		Sharded: false,
	}
	ks3 = &vindexes.Keyspace{
		Name:    "ks3",
		Sharded: true,
	}
)

// fakeSchemaInfo returns a newly allocated FakeSI describing three tables:
//   - t:  no column list, in unsharded keyspace ks1
//   - t1: authoritative columns id (INT64), in unsharded keyspace ks2
//   - t2: authoritative columns uid (INT64) and name (VARCHAR), in sharded keyspace ks3
func fakeSchemaInfo() *FakeSI {
	cols1 := []vindexes.Column{{
		Name: sqlparser.NewIdentifierCI("id"),
		Type: querypb.Type_INT64,
	}}
	cols2 := []vindexes.Column{{
		Name: sqlparser.NewIdentifierCI("uid"),
		Type: querypb.Type_INT64,
	}, {
		Name: sqlparser.NewIdentifierCI("name"),
		Type: querypb.Type_VARCHAR,
	}}

	return &FakeSI{
		Tables: map[string]*vindexes.Table{
			"t":  {Name: sqlparser.NewIdentifierCS("t"), Keyspace: ks1},
			"t1": {Name: sqlparser.NewIdentifierCS("t1"), Columns: cols1, ColumnListAuthoritative: true, Keyspace: ks2},
			"t2": {Name: sqlparser.NewIdentifierCS("t2"), Columns: cols2, ColumnListAuthoritative: true, Keyspace: ks3},
		},
	}
}