vitess.io/vitess@v0.16.2/go/vt/vtgate/semantics/early_rewriter_test.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package semantics 18 19 import ( 20 "testing" 21 22 "github.com/stretchr/testify/assert" 23 "github.com/stretchr/testify/require" 24 25 "vitess.io/vitess/go/sqltypes" 26 "vitess.io/vitess/go/vt/sqlparser" 27 "vitess.io/vitess/go/vt/vtgate/vindexes" 28 ) 29 30 func TestExpandStar(t *testing.T) { 31 ks := &vindexes.Keyspace{ 32 Name: "main", 33 Sharded: false, 34 } 35 schemaInfo := &FakeSI{ 36 Tables: map[string]*vindexes.Table{ 37 "t1": { 38 Keyspace: ks, 39 Name: sqlparser.NewIdentifierCS("t1"), 40 Columns: []vindexes.Column{{ 41 Name: sqlparser.NewIdentifierCI("a"), 42 Type: sqltypes.VarChar, 43 }, { 44 Name: sqlparser.NewIdentifierCI("b"), 45 Type: sqltypes.VarChar, 46 }, { 47 Name: sqlparser.NewIdentifierCI("c"), 48 Type: sqltypes.VarChar, 49 }}, 50 ColumnListAuthoritative: true, 51 }, 52 "t2": { 53 Keyspace: ks, 54 Name: sqlparser.NewIdentifierCS("t2"), 55 Columns: []vindexes.Column{{ 56 Name: sqlparser.NewIdentifierCI("c1"), 57 Type: sqltypes.VarChar, 58 }, { 59 Name: sqlparser.NewIdentifierCI("c2"), 60 Type: sqltypes.VarChar, 61 }}, 62 ColumnListAuthoritative: true, 63 }, 64 "t3": { // non authoritative table. 65 Keyspace: ks, 66 Name: sqlparser.NewIdentifierCS("t3"), 67 Columns: []vindexes.Column{{ 68 Name: sqlparser.NewIdentifierCI("col"), 69 Type: sqltypes.VarChar, 70 }}, 71 ColumnListAuthoritative: false, 72 }, 73 "t4": { 74 Keyspace: ks, 75 Name: sqlparser.NewIdentifierCS("t4"), 76 Columns: []vindexes.Column{{ 77 Name: sqlparser.NewIdentifierCI("c1"), 78 Type: sqltypes.VarChar, 79 }, { 80 Name: sqlparser.NewIdentifierCI("c4"), 81 Type: sqltypes.VarChar, 82 }}, 83 ColumnListAuthoritative: true, 84 }, 85 "t5": { 86 Keyspace: ks, 87 Name: sqlparser.NewIdentifierCS("t5"), 88 Columns: []vindexes.Column{{ 89 Name: sqlparser.NewIdentifierCI("a"), 90 Type: sqltypes.VarChar, 91 }, { 92 Name: sqlparser.NewIdentifierCI("b"), 93 Type: sqltypes.VarChar, 94 }}, 95 ColumnListAuthoritative: true, 96 }, 97 }, 98 } 99 cDB := "db" 100 tcases := []struct { 101 sql string 102 expSQL string 103 expErr string 104 colExpandedNumber int 105 }{{ 106 sql: "select * from t1", 107 expSQL: "select a, b, c from t1", 108 }, { 109 sql: "select t1.* from t1", 110 expSQL: "select a, b, c from t1", 111 }, { 112 sql: "select *, 42, t1.* from t1", 113 expSQL: "select a, b, c, 42, a, b, c from t1", 114 colExpandedNumber: 6, 115 }, { 116 sql: "select 42, t1.* from t1", 117 expSQL: "select 42, a, b, c from t1", 118 }, { 119 sql: "select * from t1, t2", 120 expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1, t2", 121 }, { 122 sql: "select t1.* from t1, t2", 123 expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1, t2", 124 }, { 125 sql: "select *, t1.* from t1, t2", 126 expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t1.a as a, t1.b as b, t1.c as c from t1, t2", 127 colExpandedNumber: 6, 128 }, { // aliased table 129 sql: "select * from t1 a, t2 b", 130 expSQL: "select a.a as a, a.b as b, a.c as c, b.c1 as c1, b.c2 as c2 from t1 as a, t2 as b", 131 }, { // t3 is non-authoritative table 132 sql: "select * from t3", 133 expSQL: "select * from t3", 134 }, { // t3 is non-authoritative table 135 sql: "select * from t1, t2, t3", 136 expSQL: "select * from t1, t2, t3", 137 }, { // t3 is non-authoritative table 138 sql: "select t1.*, t2.*, t3.* from t1, t2, t3", 139 expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t3.* from t1, t2, t3", 140 }, { 141 sql: "select foo.* from t1, t2", 142 expErr: "Unknown table 'foo'", 143 }, { 144 sql: "select * from t1 join t2 on t1.a = t2.c1", 145 expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1 join t2 on t1.a = t2.c1", 146 }, { 147 sql: "select * from t2 join t4 using (c1)", 148 expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4 from t2 join t4 where t2.c1 = t4.c1", 149 }, { 150 sql: "select * from t2 join t4 using (c1) join t2 as X using (c1)", 151 expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, X.c2 as c2 from t2 join t4 join t2 as X where t2.c1 = t4.c1 and t2.c1 = X.c1 and t4.c1 = X.c1", 152 }, { 153 sql: "select * from t2 join t4 using (c1), t2 as t2b join t4 as t4b using (c1)", 154 expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, t2b.c1 as c1, t2b.c2 as c2, t4b.c4 as c4 from t2 join t4, t2 as t2b join t4 as t4b where t2b.c1 = t4b.c1 and t2.c1 = t4.c1", 155 }, { 156 sql: "select * from t1 join t5 using (b)", 157 expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 where t1.b = t5.b", 158 }, { 159 sql: "select * from t1 join t5 using (b) having b = 12", 160 expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 where t1.b = t5.b having b = 12", 161 }, { 162 sql: "select 1 from t1 join t5 using (b) having b = 12", 163 expSQL: "select 1 from t1 join t5 where t1.b = t5.b having t1.b = 12", 164 }, { 165 sql: "select * from (select 12) as t", 166 expSQL: "select t.`12` from (select 12 from dual) as t", 167 }, { 168 sql: "SELECT * FROM (SELECT *, 12 AS foo FROM t3) as results", 169 expSQL: "select * from (select *, 12 as foo from t3) as results", 170 }, { 171 // if we are only star-expanding authoritative tables, we don't need to stop the expansion 172 sql: "SELECT * FROM (SELECT t2.*, 12 AS foo FROM t3, t2) as results", 173 expSQL: "select results.c1, results.c2, results.foo from (select t2.c1 as c1, t2.c2 as c2, 12 as foo from t3, t2) as results", 174 }} 175 for _, tcase := range tcases { 176 t.Run(tcase.sql, func(t *testing.T) { 177 ast, err := sqlparser.Parse(tcase.sql) 178 require.NoError(t, err) 179 selectStatement, isSelectStatement := ast.(*sqlparser.Select) 180 require.True(t, isSelectStatement, "analyzer expects a select statement") 181 st, err := Analyze(selectStatement, cDB, schemaInfo) 182 if tcase.expErr == "" { 183 require.NoError(t, err) 184 require.NoError(t, st.NotUnshardedErr) 185 require.NoError(t, st.NotSingleRouteErr) 186 found := 0 187 outer: 188 for _, selExpr := range selectStatement.SelectExprs { 189 aliasedExpr, isAliased := selExpr.(*sqlparser.AliasedExpr) 190 if !isAliased { 191 continue 192 } 193 for _, tbl := range st.ExpandedColumns { 194 for _, col := range tbl { 195 if sqlparser.Equals.Expr(aliasedExpr.Expr, col) { 196 found++ 197 continue outer 198 } 199 } 200 } 201 } 202 if tcase.colExpandedNumber == 0 { 203 for _, tbl := range st.ExpandedColumns { 204 found -= len(tbl) 205 } 206 require.Zero(t, found) 207 } else { 208 require.Equal(t, tcase.colExpandedNumber, found) 209 } 210 assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) 211 } else { 212 require.EqualError(t, err, tcase.expErr) 213 } 214 }) 215 } 216 } 217 218 func TestRewriteJoinUsingColumns(t *testing.T) { 219 schemaInfo := &FakeSI{ 220 Tables: map[string]*vindexes.Table{ 221 "t1": { 222 Name: sqlparser.NewIdentifierCS("t1"), 223 Columns: []vindexes.Column{{ 224 Name: sqlparser.NewIdentifierCI("a"), 225 Type: sqltypes.VarChar, 226 }, { 227 Name: sqlparser.NewIdentifierCI("b"), 228 Type: sqltypes.VarChar, 229 }, { 230 Name: sqlparser.NewIdentifierCI("c"), 231 Type: sqltypes.VarChar, 232 }}, 233 ColumnListAuthoritative: true, 234 }, 235 "t2": { 236 Name: sqlparser.NewIdentifierCS("t2"), 237 Columns: []vindexes.Column{{ 238 Name: sqlparser.NewIdentifierCI("a"), 239 Type: sqltypes.VarChar, 240 }, { 241 Name: sqlparser.NewIdentifierCI("b"), 242 Type: sqltypes.VarChar, 243 }, { 244 Name: sqlparser.NewIdentifierCI("c"), 245 Type: sqltypes.VarChar, 246 }}, 247 ColumnListAuthoritative: true, 248 }, 249 "t3": { 250 Name: sqlparser.NewIdentifierCS("t3"), 251 Columns: []vindexes.Column{{ 252 Name: sqlparser.NewIdentifierCI("a"), 253 Type: sqltypes.VarChar, 254 }, { 255 Name: sqlparser.NewIdentifierCI("b"), 256 Type: sqltypes.VarChar, 257 }, { 258 Name: sqlparser.NewIdentifierCI("c"), 259 Type: sqltypes.VarChar, 260 }}, 261 ColumnListAuthoritative: true, 262 }, 263 }, 264 } 265 cDB := "db" 266 tcases := []struct { 267 sql string 268 expSQL string 269 expErr string 270 }{{ 271 sql: "select 1 from t1 join t2 using (a) where a = 42", 272 expSQL: "select 1 from t1 join t2 where t1.a = t2.a and t1.a = 42", 273 }, { 274 sql: "select 1 from t1 join t2 using (a), t3 where a = 42", 275 expErr: "Column 'a' in field list is ambiguous", 276 }, { 277 sql: "select 1 from t1 join t2 using (a), t1 as b join t3 on (a) where a = 42", 278 expErr: "Column 'a' in field list is ambiguous", 279 }} 280 for _, tcase := range tcases { 281 t.Run(tcase.sql, func(t *testing.T) { 282 ast, err := sqlparser.Parse(tcase.sql) 283 require.NoError(t, err) 284 selectStatement, isSelectStatement := ast.(*sqlparser.Select) 285 require.True(t, isSelectStatement, "analyzer expects a select statement") 286 _, err = Analyze(selectStatement, cDB, schemaInfo) 287 if tcase.expErr == "" { 288 require.NoError(t, err) 289 assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) 290 } else { 291 require.EqualError(t, err, tcase.expErr) 292 } 293 }) 294 } 295 296 } 297 298 func TestOrderByGroupByLiteral(t *testing.T) { 299 schemaInfo := &FakeSI{ 300 Tables: map[string]*vindexes.Table{}, 301 } 302 cDB := "db" 303 tcases := []struct { 304 sql string 305 expSQL string 306 expErr string 307 }{{ 308 sql: "select 1 as id from t1 order by 1", 309 expSQL: "select 1 as id from t1 order by id asc", 310 }, { 311 sql: "select t1.col from t1 order by 1", 312 expSQL: "select t1.col from t1 order by t1.col asc", 313 }, { 314 sql: "select t1.col from t1 group by 1", 315 expSQL: "select t1.col from t1 group by t1.col", 316 }, { 317 sql: "select t1.col as xyz from t1 group by 1", 318 expSQL: "select t1.col as xyz from t1 group by xyz", 319 }, { 320 sql: "select t1.col as xyz, count(*) from t1 group by 1 order by 2", 321 expSQL: "select t1.col as xyz, count(*) from t1 group by xyz order by count(*) asc", 322 }, { 323 sql: "select id from t1 group by 2", 324 expErr: "Unknown column '2' in 'group statement'", 325 }, { 326 sql: "select id from t1 order by 2", 327 expErr: "Unknown column '2' in 'order clause'", 328 }, { 329 sql: "select *, id from t1 order by 2", 330 expErr: "cannot use column offsets in order clause when using `*`", 331 }, { 332 sql: "select *, id from t1 group by 2", 333 expErr: "cannot use column offsets in group statement when using `*`", 334 }, { 335 sql: "select id from t1 order by 1 collate utf8_general_ci", 336 expSQL: "select id from t1 order by id collate utf8_general_ci asc", 337 }} 338 for _, tcase := range tcases { 339 t.Run(tcase.sql, func(t *testing.T) { 340 ast, err := sqlparser.Parse(tcase.sql) 341 require.NoError(t, err) 342 selectStatement := ast.(*sqlparser.Select) 343 _, err = Analyze(selectStatement, cDB, schemaInfo) 344 if tcase.expErr == "" { 345 require.NoError(t, err) 346 assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) 347 } else { 348 require.EqualError(t, err, tcase.expErr) 349 } 350 }) 351 } 352 } 353 354 func TestHavingAndOrderByColumnName(t *testing.T) { 355 schemaInfo := &FakeSI{ 356 Tables: map[string]*vindexes.Table{}, 357 } 358 cDB := "db" 359 tcases := []struct { 360 sql string 361 expSQL string 362 expErr string 363 }{{ 364 sql: "select id, sum(foo) as sumOfFoo from t1 having sumOfFoo > 1", 365 expSQL: "select id, sum(foo) as sumOfFoo from t1 having sum(foo) > 1", 366 }, { 367 sql: "select id, sum(foo) as sumOfFoo from t1 order by sumOfFoo", 368 expSQL: "select id, sum(foo) as sumOfFoo from t1 order by sum(foo) asc", 369 }, { 370 sql: "select id, sum(foo) as foo from t1 having sum(foo) > 1", 371 expSQL: "select id, sum(foo) as foo from t1 having sum(foo) > 1", 372 }} 373 for _, tcase := range tcases { 374 t.Run(tcase.sql, func(t *testing.T) { 375 ast, err := sqlparser.Parse(tcase.sql) 376 require.NoError(t, err) 377 selectStatement := ast.(*sqlparser.Select) 378 _, err = Analyze(selectStatement, cDB, schemaInfo) 379 if tcase.expErr == "" { 380 require.NoError(t, err) 381 assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) 382 } else { 383 require.EqualError(t, err, tcase.expErr) 384 } 385 }) 386 } 387 } 388 389 func TestSemTableDependenciesAfterExpandStar(t *testing.T) { 390 schemaInfo := &FakeSI{Tables: map[string]*vindexes.Table{ 391 "t1": { 392 Name: sqlparser.NewIdentifierCS("t1"), 393 Columns: []vindexes.Column{{ 394 Name: sqlparser.NewIdentifierCI("a"), 395 Type: sqltypes.VarChar, 396 }}, 397 ColumnListAuthoritative: true, 398 }}} 399 tcases := []struct { 400 sql string 401 expSQL string 402 sameTbl int 403 otherTbl int 404 expandedCol int 405 }{{ 406 sql: "select a, * from t1", 407 expSQL: "select a, a from t1", 408 otherTbl: -1, sameTbl: 0, expandedCol: 1, 409 }, { 410 sql: "select t2.a, t1.a, t1.* from t1, t2", 411 expSQL: "select t2.a, t1.a, t1.a as a from t1, t2", 412 otherTbl: 0, sameTbl: 1, expandedCol: 2, 413 }, { 414 sql: "select t2.a, t.a, t.* from t1 t, t2", 415 expSQL: "select t2.a, t.a, t.a as a from t1 as t, t2", 416 otherTbl: 0, sameTbl: 1, expandedCol: 2, 417 }} 418 for _, tcase := range tcases { 419 t.Run(tcase.sql, func(t *testing.T) { 420 ast, err := sqlparser.Parse(tcase.sql) 421 require.NoError(t, err) 422 selectStatement, isSelectStatement := ast.(*sqlparser.Select) 423 require.True(t, isSelectStatement, "analyzer expects a select statement") 424 semTable, err := Analyze(selectStatement, "", schemaInfo) 425 require.NoError(t, err) 426 assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) 427 if tcase.otherTbl != -1 { 428 assert.NotEqual(t, 429 semTable.RecursiveDeps(selectStatement.SelectExprs[tcase.otherTbl].(*sqlparser.AliasedExpr).Expr), 430 semTable.RecursiveDeps(selectStatement.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr), 431 ) 432 } 433 if tcase.sameTbl != -1 { 434 assert.Equal(t, 435 semTable.RecursiveDeps(selectStatement.SelectExprs[tcase.sameTbl].(*sqlparser.AliasedExpr).Expr), 436 semTable.RecursiveDeps(selectStatement.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr), 437 ) 438 } 439 }) 440 } 441 }