vitess.io/vitess@v0.16.2/go/vt/vtgate/semantics/early_rewriter_test.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package semantics
    18  
    19  import (
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  
    25  	"vitess.io/vitess/go/sqltypes"
    26  	"vitess.io/vitess/go/vt/sqlparser"
    27  	"vitess.io/vitess/go/vt/vtgate/vindexes"
    28  )
    29  
    30  func TestExpandStar(t *testing.T) {
    31  	ks := &vindexes.Keyspace{
    32  		Name:    "main",
    33  		Sharded: false,
    34  	}
    35  	schemaInfo := &FakeSI{
    36  		Tables: map[string]*vindexes.Table{
    37  			"t1": {
    38  				Keyspace: ks,
    39  				Name:     sqlparser.NewIdentifierCS("t1"),
    40  				Columns: []vindexes.Column{{
    41  					Name: sqlparser.NewIdentifierCI("a"),
    42  					Type: sqltypes.VarChar,
    43  				}, {
    44  					Name: sqlparser.NewIdentifierCI("b"),
    45  					Type: sqltypes.VarChar,
    46  				}, {
    47  					Name: sqlparser.NewIdentifierCI("c"),
    48  					Type: sqltypes.VarChar,
    49  				}},
    50  				ColumnListAuthoritative: true,
    51  			},
    52  			"t2": {
    53  				Keyspace: ks,
    54  				Name:     sqlparser.NewIdentifierCS("t2"),
    55  				Columns: []vindexes.Column{{
    56  					Name: sqlparser.NewIdentifierCI("c1"),
    57  					Type: sqltypes.VarChar,
    58  				}, {
    59  					Name: sqlparser.NewIdentifierCI("c2"),
    60  					Type: sqltypes.VarChar,
    61  				}},
    62  				ColumnListAuthoritative: true,
    63  			},
    64  			"t3": { // non authoritative table.
    65  				Keyspace: ks,
    66  				Name:     sqlparser.NewIdentifierCS("t3"),
    67  				Columns: []vindexes.Column{{
    68  					Name: sqlparser.NewIdentifierCI("col"),
    69  					Type: sqltypes.VarChar,
    70  				}},
    71  				ColumnListAuthoritative: false,
    72  			},
    73  			"t4": {
    74  				Keyspace: ks,
    75  				Name:     sqlparser.NewIdentifierCS("t4"),
    76  				Columns: []vindexes.Column{{
    77  					Name: sqlparser.NewIdentifierCI("c1"),
    78  					Type: sqltypes.VarChar,
    79  				}, {
    80  					Name: sqlparser.NewIdentifierCI("c4"),
    81  					Type: sqltypes.VarChar,
    82  				}},
    83  				ColumnListAuthoritative: true,
    84  			},
    85  			"t5": {
    86  				Keyspace: ks,
    87  				Name:     sqlparser.NewIdentifierCS("t5"),
    88  				Columns: []vindexes.Column{{
    89  					Name: sqlparser.NewIdentifierCI("a"),
    90  					Type: sqltypes.VarChar,
    91  				}, {
    92  					Name: sqlparser.NewIdentifierCI("b"),
    93  					Type: sqltypes.VarChar,
    94  				}},
    95  				ColumnListAuthoritative: true,
    96  			},
    97  		},
    98  	}
    99  	cDB := "db"
   100  	tcases := []struct {
   101  		sql               string
   102  		expSQL            string
   103  		expErr            string
   104  		colExpandedNumber int
   105  	}{{
   106  		sql:    "select * from t1",
   107  		expSQL: "select a, b, c from t1",
   108  	}, {
   109  		sql:    "select t1.* from t1",
   110  		expSQL: "select a, b, c from t1",
   111  	}, {
   112  		sql:               "select *, 42, t1.* from t1",
   113  		expSQL:            "select a, b, c, 42, a, b, c from t1",
   114  		colExpandedNumber: 6,
   115  	}, {
   116  		sql:    "select 42, t1.* from t1",
   117  		expSQL: "select 42, a, b, c from t1",
   118  	}, {
   119  		sql:    "select * from t1, t2",
   120  		expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1, t2",
   121  	}, {
   122  		sql:    "select t1.* from t1, t2",
   123  		expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1, t2",
   124  	}, {
   125  		sql:               "select *, t1.* from t1, t2",
   126  		expSQL:            "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t1.a as a, t1.b as b, t1.c as c from t1, t2",
   127  		colExpandedNumber: 6,
   128  	}, { // aliased table
   129  		sql:    "select * from t1 a, t2 b",
   130  		expSQL: "select a.a as a, a.b as b, a.c as c, b.c1 as c1, b.c2 as c2 from t1 as a, t2 as b",
   131  	}, { // t3 is non-authoritative table
   132  		sql:    "select * from t3",
   133  		expSQL: "select * from t3",
   134  	}, { // t3 is non-authoritative table
   135  		sql:    "select * from t1, t2, t3",
   136  		expSQL: "select * from t1, t2, t3",
   137  	}, { // t3 is non-authoritative table
   138  		sql:    "select t1.*, t2.*, t3.* from t1, t2, t3",
   139  		expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t3.* from t1, t2, t3",
   140  	}, {
   141  		sql:    "select foo.* from t1, t2",
   142  		expErr: "Unknown table 'foo'",
   143  	}, {
   144  		sql:    "select * from t1 join t2 on t1.a = t2.c1",
   145  		expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1 join t2 on t1.a = t2.c1",
   146  	}, {
   147  		sql:    "select * from t2 join t4 using (c1)",
   148  		expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4 from t2 join t4 where t2.c1 = t4.c1",
   149  	}, {
   150  		sql:    "select * from t2 join t4 using (c1) join t2 as X using (c1)",
   151  		expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, X.c2 as c2 from t2 join t4 join t2 as X where t2.c1 = t4.c1 and t2.c1 = X.c1 and t4.c1 = X.c1",
   152  	}, {
   153  		sql:    "select * from t2 join t4 using (c1), t2 as t2b join t4 as t4b using (c1)",
   154  		expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, t2b.c1 as c1, t2b.c2 as c2, t4b.c4 as c4 from t2 join t4, t2 as t2b join t4 as t4b where t2b.c1 = t4b.c1 and t2.c1 = t4.c1",
   155  	}, {
   156  		sql:    "select * from t1 join t5 using (b)",
   157  		expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 where t1.b = t5.b",
   158  	}, {
   159  		sql:    "select * from t1 join t5 using (b) having b = 12",
   160  		expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 where t1.b = t5.b having b = 12",
   161  	}, {
   162  		sql:    "select 1 from t1 join t5 using (b) having b = 12",
   163  		expSQL: "select 1 from t1 join t5 where t1.b = t5.b having t1.b = 12",
   164  	}, {
   165  		sql:    "select * from (select 12) as t",
   166  		expSQL: "select t.`12` from (select 12 from dual) as t",
   167  	}, {
   168  		sql:    "SELECT * FROM (SELECT *, 12 AS foo FROM t3) as results",
   169  		expSQL: "select * from (select *, 12 as foo from t3) as results",
   170  	}, {
   171  		// if we are only star-expanding authoritative tables, we don't need to stop the expansion
   172  		sql:    "SELECT * FROM (SELECT t2.*, 12 AS foo FROM t3, t2) as results",
   173  		expSQL: "select results.c1, results.c2, results.foo from (select t2.c1 as c1, t2.c2 as c2, 12 as foo from t3, t2) as results",
   174  	}}
   175  	for _, tcase := range tcases {
   176  		t.Run(tcase.sql, func(t *testing.T) {
   177  			ast, err := sqlparser.Parse(tcase.sql)
   178  			require.NoError(t, err)
   179  			selectStatement, isSelectStatement := ast.(*sqlparser.Select)
   180  			require.True(t, isSelectStatement, "analyzer expects a select statement")
   181  			st, err := Analyze(selectStatement, cDB, schemaInfo)
   182  			if tcase.expErr == "" {
   183  				require.NoError(t, err)
   184  				require.NoError(t, st.NotUnshardedErr)
   185  				require.NoError(t, st.NotSingleRouteErr)
   186  				found := 0
   187  			outer:
   188  				for _, selExpr := range selectStatement.SelectExprs {
   189  					aliasedExpr, isAliased := selExpr.(*sqlparser.AliasedExpr)
   190  					if !isAliased {
   191  						continue
   192  					}
   193  					for _, tbl := range st.ExpandedColumns {
   194  						for _, col := range tbl {
   195  							if sqlparser.Equals.Expr(aliasedExpr.Expr, col) {
   196  								found++
   197  								continue outer
   198  							}
   199  						}
   200  					}
   201  				}
   202  				if tcase.colExpandedNumber == 0 {
   203  					for _, tbl := range st.ExpandedColumns {
   204  						found -= len(tbl)
   205  					}
   206  					require.Zero(t, found)
   207  				} else {
   208  					require.Equal(t, tcase.colExpandedNumber, found)
   209  				}
   210  				assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement))
   211  			} else {
   212  				require.EqualError(t, err, tcase.expErr)
   213  			}
   214  		})
   215  	}
   216  }
   217  
   218  func TestRewriteJoinUsingColumns(t *testing.T) {
   219  	schemaInfo := &FakeSI{
   220  		Tables: map[string]*vindexes.Table{
   221  			"t1": {
   222  				Name: sqlparser.NewIdentifierCS("t1"),
   223  				Columns: []vindexes.Column{{
   224  					Name: sqlparser.NewIdentifierCI("a"),
   225  					Type: sqltypes.VarChar,
   226  				}, {
   227  					Name: sqlparser.NewIdentifierCI("b"),
   228  					Type: sqltypes.VarChar,
   229  				}, {
   230  					Name: sqlparser.NewIdentifierCI("c"),
   231  					Type: sqltypes.VarChar,
   232  				}},
   233  				ColumnListAuthoritative: true,
   234  			},
   235  			"t2": {
   236  				Name: sqlparser.NewIdentifierCS("t2"),
   237  				Columns: []vindexes.Column{{
   238  					Name: sqlparser.NewIdentifierCI("a"),
   239  					Type: sqltypes.VarChar,
   240  				}, {
   241  					Name: sqlparser.NewIdentifierCI("b"),
   242  					Type: sqltypes.VarChar,
   243  				}, {
   244  					Name: sqlparser.NewIdentifierCI("c"),
   245  					Type: sqltypes.VarChar,
   246  				}},
   247  				ColumnListAuthoritative: true,
   248  			},
   249  			"t3": {
   250  				Name: sqlparser.NewIdentifierCS("t3"),
   251  				Columns: []vindexes.Column{{
   252  					Name: sqlparser.NewIdentifierCI("a"),
   253  					Type: sqltypes.VarChar,
   254  				}, {
   255  					Name: sqlparser.NewIdentifierCI("b"),
   256  					Type: sqltypes.VarChar,
   257  				}, {
   258  					Name: sqlparser.NewIdentifierCI("c"),
   259  					Type: sqltypes.VarChar,
   260  				}},
   261  				ColumnListAuthoritative: true,
   262  			},
   263  		},
   264  	}
   265  	cDB := "db"
   266  	tcases := []struct {
   267  		sql    string
   268  		expSQL string
   269  		expErr string
   270  	}{{
   271  		sql:    "select 1 from t1 join t2 using (a) where a = 42",
   272  		expSQL: "select 1 from t1 join t2 where t1.a = t2.a and t1.a = 42",
   273  	}, {
   274  		sql:    "select 1 from t1 join t2 using (a), t3 where a = 42",
   275  		expErr: "Column 'a' in field list is ambiguous",
   276  	}, {
   277  		sql:    "select 1 from t1 join t2 using (a), t1 as b join t3 on (a) where a = 42",
   278  		expErr: "Column 'a' in field list is ambiguous",
   279  	}}
   280  	for _, tcase := range tcases {
   281  		t.Run(tcase.sql, func(t *testing.T) {
   282  			ast, err := sqlparser.Parse(tcase.sql)
   283  			require.NoError(t, err)
   284  			selectStatement, isSelectStatement := ast.(*sqlparser.Select)
   285  			require.True(t, isSelectStatement, "analyzer expects a select statement")
   286  			_, err = Analyze(selectStatement, cDB, schemaInfo)
   287  			if tcase.expErr == "" {
   288  				require.NoError(t, err)
   289  				assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement))
   290  			} else {
   291  				require.EqualError(t, err, tcase.expErr)
   292  			}
   293  		})
   294  	}
   295  
   296  }
   297  
   298  func TestOrderByGroupByLiteral(t *testing.T) {
   299  	schemaInfo := &FakeSI{
   300  		Tables: map[string]*vindexes.Table{},
   301  	}
   302  	cDB := "db"
   303  	tcases := []struct {
   304  		sql    string
   305  		expSQL string
   306  		expErr string
   307  	}{{
   308  		sql:    "select 1 as id from t1 order by 1",
   309  		expSQL: "select 1 as id from t1 order by id asc",
   310  	}, {
   311  		sql:    "select t1.col from t1 order by 1",
   312  		expSQL: "select t1.col from t1 order by t1.col asc",
   313  	}, {
   314  		sql:    "select t1.col from t1 group by 1",
   315  		expSQL: "select t1.col from t1 group by t1.col",
   316  	}, {
   317  		sql:    "select t1.col as xyz from t1 group by 1",
   318  		expSQL: "select t1.col as xyz from t1 group by xyz",
   319  	}, {
   320  		sql:    "select t1.col as xyz, count(*) from t1 group by 1 order by 2",
   321  		expSQL: "select t1.col as xyz, count(*) from t1 group by xyz order by count(*) asc",
   322  	}, {
   323  		sql:    "select id from t1 group by 2",
   324  		expErr: "Unknown column '2' in 'group statement'",
   325  	}, {
   326  		sql:    "select id from t1 order by 2",
   327  		expErr: "Unknown column '2' in 'order clause'",
   328  	}, {
   329  		sql:    "select *, id from t1 order by 2",
   330  		expErr: "cannot use column offsets in order clause when using `*`",
   331  	}, {
   332  		sql:    "select *, id from t1 group by 2",
   333  		expErr: "cannot use column offsets in group statement when using `*`",
   334  	}, {
   335  		sql:    "select id from t1 order by 1 collate utf8_general_ci",
   336  		expSQL: "select id from t1 order by id collate utf8_general_ci asc",
   337  	}}
   338  	for _, tcase := range tcases {
   339  		t.Run(tcase.sql, func(t *testing.T) {
   340  			ast, err := sqlparser.Parse(tcase.sql)
   341  			require.NoError(t, err)
   342  			selectStatement := ast.(*sqlparser.Select)
   343  			_, err = Analyze(selectStatement, cDB, schemaInfo)
   344  			if tcase.expErr == "" {
   345  				require.NoError(t, err)
   346  				assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement))
   347  			} else {
   348  				require.EqualError(t, err, tcase.expErr)
   349  			}
   350  		})
   351  	}
   352  }
   353  
   354  func TestHavingAndOrderByColumnName(t *testing.T) {
   355  	schemaInfo := &FakeSI{
   356  		Tables: map[string]*vindexes.Table{},
   357  	}
   358  	cDB := "db"
   359  	tcases := []struct {
   360  		sql    string
   361  		expSQL string
   362  		expErr string
   363  	}{{
   364  		sql:    "select id, sum(foo) as sumOfFoo from t1 having sumOfFoo > 1",
   365  		expSQL: "select id, sum(foo) as sumOfFoo from t1 having sum(foo) > 1",
   366  	}, {
   367  		sql:    "select id, sum(foo) as sumOfFoo from t1 order by sumOfFoo",
   368  		expSQL: "select id, sum(foo) as sumOfFoo from t1 order by sum(foo) asc",
   369  	}, {
   370  		sql:    "select id, sum(foo) as foo from t1 having sum(foo) > 1",
   371  		expSQL: "select id, sum(foo) as foo from t1 having sum(foo) > 1",
   372  	}}
   373  	for _, tcase := range tcases {
   374  		t.Run(tcase.sql, func(t *testing.T) {
   375  			ast, err := sqlparser.Parse(tcase.sql)
   376  			require.NoError(t, err)
   377  			selectStatement := ast.(*sqlparser.Select)
   378  			_, err = Analyze(selectStatement, cDB, schemaInfo)
   379  			if tcase.expErr == "" {
   380  				require.NoError(t, err)
   381  				assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement))
   382  			} else {
   383  				require.EqualError(t, err, tcase.expErr)
   384  			}
   385  		})
   386  	}
   387  }
   388  
   389  func TestSemTableDependenciesAfterExpandStar(t *testing.T) {
   390  	schemaInfo := &FakeSI{Tables: map[string]*vindexes.Table{
   391  		"t1": {
   392  			Name: sqlparser.NewIdentifierCS("t1"),
   393  			Columns: []vindexes.Column{{
   394  				Name: sqlparser.NewIdentifierCI("a"),
   395  				Type: sqltypes.VarChar,
   396  			}},
   397  			ColumnListAuthoritative: true,
   398  		}}}
   399  	tcases := []struct {
   400  		sql         string
   401  		expSQL      string
   402  		sameTbl     int
   403  		otherTbl    int
   404  		expandedCol int
   405  	}{{
   406  		sql:      "select a, * from t1",
   407  		expSQL:   "select a, a from t1",
   408  		otherTbl: -1, sameTbl: 0, expandedCol: 1,
   409  	}, {
   410  		sql:      "select t2.a, t1.a, t1.* from t1, t2",
   411  		expSQL:   "select t2.a, t1.a, t1.a as a from t1, t2",
   412  		otherTbl: 0, sameTbl: 1, expandedCol: 2,
   413  	}, {
   414  		sql:      "select t2.a, t.a, t.* from t1 t, t2",
   415  		expSQL:   "select t2.a, t.a, t.a as a from t1 as t, t2",
   416  		otherTbl: 0, sameTbl: 1, expandedCol: 2,
   417  	}}
   418  	for _, tcase := range tcases {
   419  		t.Run(tcase.sql, func(t *testing.T) {
   420  			ast, err := sqlparser.Parse(tcase.sql)
   421  			require.NoError(t, err)
   422  			selectStatement, isSelectStatement := ast.(*sqlparser.Select)
   423  			require.True(t, isSelectStatement, "analyzer expects a select statement")
   424  			semTable, err := Analyze(selectStatement, "", schemaInfo)
   425  			require.NoError(t, err)
   426  			assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement))
   427  			if tcase.otherTbl != -1 {
   428  				assert.NotEqual(t,
   429  					semTable.RecursiveDeps(selectStatement.SelectExprs[tcase.otherTbl].(*sqlparser.AliasedExpr).Expr),
   430  					semTable.RecursiveDeps(selectStatement.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr),
   431  				)
   432  			}
   433  			if tcase.sameTbl != -1 {
   434  				assert.Equal(t,
   435  					semTable.RecursiveDeps(selectStatement.SelectExprs[tcase.sameTbl].(*sqlparser.AliasedExpr).Expr),
   436  					semTable.RecursiveDeps(selectStatement.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr),
   437  				)
   438  			}
   439  		})
   440  	}
   441  }