vitess.io/vitess@v0.16.2/go/vt/sqlparser/ast_rewriting_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"fmt"
    21  	"testing"
    22  
    23  	"github.com/stretchr/testify/assert"
    24  
    25  	"vitess.io/vitess/go/vt/sysvars"
    26  
    27  	"github.com/stretchr/testify/require"
    28  )
    29  
    30  type testCaseSetVar struct {
    31  	in, expected, setVarComment string
    32  }
    33  
    34  type testCaseSysVar struct {
    35  	in, expected string
    36  	sysVar       map[string]string
    37  }
    38  
    39  type myTestCase struct {
    40  	in, expected                                                          string
    41  	liid, db, foundRows, rowCount, rawGTID, rawTimeout, sessTrackGTID     bool
    42  	ddlStrategy, sessionUUID, sessionEnableSystemSettings                 bool
    43  	udv                                                                   int
    44  	autocommit, clientFoundRows, skipQueryPlanCache, socket, queryTimeout bool
    45  	sqlSelectLimit, transactionMode, workload, version, versionComment    bool
    46  	txIsolation                                                           bool
    47  }
    48  
    49  func TestRewrites(in *testing.T) {
    50  	tests := []myTestCase{{
    51  		in:       "SELECT 42",
    52  		expected: "SELECT 42",
    53  		// no bindvar needs
    54  	}, {
    55  		in:       "SELECT @@version",
    56  		expected: "SELECT :__vtversion as `@@version`",
    57  		version:  true,
    58  	}, {
    59  		in:           "SELECT @@query_timeout",
    60  		expected:     "SELECT :__vtquery_timeout as `@@query_timeout`",
    61  		queryTimeout: true,
    62  	}, {
    63  		in:             "SELECT @@version_comment",
    64  		expected:       "SELECT :__vtversion_comment as `@@version_comment`",
    65  		versionComment: true,
    66  	}, {
    67  		in:                          "SELECT @@enable_system_settings",
    68  		expected:                    "SELECT :__vtenable_system_settings as `@@enable_system_settings`",
    69  		sessionEnableSystemSettings: true,
    70  	}, {
    71  		in:       "SELECT last_insert_id()",
    72  		expected: "SELECT :__lastInsertId as `last_insert_id()`",
    73  		liid:     true,
    74  	}, {
    75  		in:       "SELECT database()",
    76  		expected: "SELECT :__vtdbname as `database()`",
    77  		db:       true,
    78  	}, {
    79  		in:       "SELECT database() from test",
    80  		expected: "SELECT database() from test",
    81  		// no bindvar needs
    82  	}, {
    83  		in:       "SELECT last_insert_id() as test",
    84  		expected: "SELECT :__lastInsertId as test",
    85  		liid:     true,
    86  	}, {
    87  		in:       "SELECT last_insert_id() + database()",
    88  		expected: "SELECT :__lastInsertId + :__vtdbname as `last_insert_id() + database()`",
    89  		db:       true, liid: true,
    90  	}, {
    91  		// unnest database() call
    92  		in:       "select (select database()) from test",
    93  		expected: "select database() as `(select database() from dual)` from test",
    94  		// no bindvar needs
    95  	}, {
    96  		// unnest database() call
    97  		in:       "select (select database() from dual) from test",
    98  		expected: "select database() as `(select database() from dual)` from test",
    99  		// no bindvar needs
   100  	}, {
   101  		in:       "select (select database() from dual) from dual",
   102  		expected: "select :__vtdbname as `(select database() from dual)` from dual",
   103  		db:       true,
   104  	}, {
   105  		// don't unnest solo columns
   106  		in:       "select 1 as foobar, (select foobar)",
   107  		expected: "select 1 as foobar, (select foobar from dual) from dual",
   108  	}, {
   109  		in:       "select id from user where database()",
   110  		expected: "select id from user where database()",
   111  		// no bindvar needs
   112  	}, {
   113  		in:       "select table_name from information_schema.tables where table_schema = database()",
   114  		expected: "select table_name from information_schema.tables where table_schema = database()",
   115  		// no bindvar needs
   116  	}, {
   117  		in:       "select schema()",
   118  		expected: "select :__vtdbname as `schema()`",
   119  		db:       true,
   120  	}, {
   121  		in:        "select found_rows()",
   122  		expected:  "select :__vtfrows as `found_rows()`",
   123  		foundRows: true,
   124  	}, {
   125  		in:       "select @`x y`",
   126  		expected: "select :__vtudvx_y as `@``x y``` from dual",
   127  		udv:      1,
   128  	}, {
   129  		in:       "select id from t where id = @x and val = @y",
   130  		expected: "select id from t where id = :__vtudvx and val = :__vtudvy",
   131  		db:       false, udv: 2,
   132  	}, {
   133  		in:       "insert into t(id) values(@xyx)",
   134  		expected: "insert into t(id) values(:__vtudvxyx)",
   135  		db:       false, udv: 1,
   136  	}, {
   137  		in:       "select row_count()",
   138  		expected: "select :__vtrcount as `row_count()`",
   139  		rowCount: true,
   140  	}, {
   141  		in:       "SELECT lower(database())",
   142  		expected: "SELECT lower(:__vtdbname) as `lower(database())`",
   143  		db:       true,
   144  	}, {
   145  		in:         "SELECT @@autocommit",
   146  		expected:   "SELECT :__vtautocommit as `@@autocommit`",
   147  		autocommit: true,
   148  	}, {
   149  		in:              "SELECT @@client_found_rows",
   150  		expected:        "SELECT :__vtclient_found_rows as `@@client_found_rows`",
   151  		clientFoundRows: true,
   152  	}, {
   153  		in:                 "SELECT @@skip_query_plan_cache",
   154  		expected:           "SELECT :__vtskip_query_plan_cache as `@@skip_query_plan_cache`",
   155  		skipQueryPlanCache: true,
   156  	}, {
   157  		in:             "SELECT @@sql_select_limit",
   158  		expected:       "SELECT :__vtsql_select_limit as `@@sql_select_limit`",
   159  		sqlSelectLimit: true,
   160  	}, {
   161  		in:              "SELECT @@transaction_mode",
   162  		expected:        "SELECT :__vttransaction_mode as `@@transaction_mode`",
   163  		transactionMode: true,
   164  	}, {
   165  		in:       "SELECT @@workload",
   166  		expected: "SELECT :__vtworkload as `@@workload`",
   167  		workload: true,
   168  	}, {
   169  		in:       "SELECT @@socket",
   170  		expected: "SELECT :__vtsocket as `@@socket`",
   171  		socket:   true,
   172  	}, {
   173  		in:       "select (select 42) from dual",
   174  		expected: "select 42 as `(select 42 from dual)` from dual",
   175  	}, {
   176  		in:       "select exists(select 1) from user",
   177  		expected: "select exists(select 1 limit 1) from user",
   178  	}, {
   179  		in:       "select * from user where col = (select 42)",
   180  		expected: "select * from user where col = 42",
   181  	}, {
   182  		in:       "select * from (select 42) as t", // this is not an expression, and should not be rewritten
   183  		expected: "select * from (select 42) as t",
   184  	}, {
   185  		in:       `select (select (select (select (select (select last_insert_id()))))) as x`,
   186  		expected: "select :__lastInsertId as x from dual",
   187  		liid:     true,
   188  	}, {
   189  		in:          `select * from user where col = @@ddl_strategy`,
   190  		expected:    "select * from user where col = :__vtddl_strategy",
   191  		ddlStrategy: true,
   192  	}, {
   193  		in:       `select * from user where col = @@read_after_write_gtid OR col = @@read_after_write_timeout OR col = @@session_track_gtids`,
   194  		expected: "select * from user where col = :__vtread_after_write_gtid or col = :__vtread_after_write_timeout or col = :__vtsession_track_gtids",
   195  		rawGTID:  true, rawTimeout: true, sessTrackGTID: true,
   196  	}, {
   197  		in:       "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual)",
   198  		expected: "SELECT * FROM tbl WHERE id IN (1)",
   199  	}, {
   200  		in:       "SELECT * FROM tbl WHERE id IN (SELECT last_insert_id() FROM dual)",
   201  		expected: "SELECT * FROM tbl WHERE id IN (:__lastInsertId)",
   202  		liid:     true,
   203  	}, {
   204  		in:       "SELECT * FROM tbl WHERE id IN (SELECT (SELECT 1 FROM dual WHERE 1 = 0) FROM dual)",
   205  		expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual WHERE 1 = 0)",
   206  	}, {
   207  		in:       "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual WHERE 1 = 0)",
   208  		expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual WHERE 1 = 0)",
   209  	}, {
   210  		in:       "SELECT * FROM tbl WHERE id IN (SELECT 1,2 FROM dual)",
   211  		expected: "SELECT * FROM tbl WHERE id IN (SELECT 1,2 FROM dual)",
   212  	}, {
   213  		in:       "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual ORDER BY 1)",
   214  		expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual ORDER BY 1)",
   215  	}, {
   216  		in:       "SELECT * FROM tbl WHERE id IN (SELECT id FROM user GROUP BY id)",
   217  		expected: "SELECT * FROM tbl WHERE id IN (SELECT id FROM user GROUP BY id)",
   218  	}, {
   219  		in:       "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual, user)",
   220  		expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual, user)",
   221  	}, {
   222  		in:       "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual limit 1)",
   223  		expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual limit 1)",
   224  	}, {
   225  		// SELECT * behaves different depending the join type used, so if that has been used, we won't rewrite
   226  		in:       "SELECT * FROM A JOIN B USING (id1,id2,id3)",
   227  		expected: "SELECT * FROM A JOIN B USING (id1,id2,id3)",
   228  	}, {
   229  		in:       "CALL proc(@foo)",
   230  		expected: "CALL proc(:__vtudvfoo)",
   231  		udv:      1,
   232  	}, {
   233  		in:       "SELECT * FROM tbl WHERE NOT id = 42",
   234  		expected: "SELECT * FROM tbl WHERE id != 42",
   235  	}, {
   236  		in:       "SELECT * FROM tbl WHERE not id < 12",
   237  		expected: "SELECT * FROM tbl WHERE id >= 12",
   238  	}, {
   239  		in:       "SELECT * FROM tbl WHERE not id > 12",
   240  		expected: "SELECT * FROM tbl WHERE id <= 12",
   241  	}, {
   242  		in:       "SELECT * FROM tbl WHERE not id <= 33",
   243  		expected: "SELECT * FROM tbl WHERE id > 33",
   244  	}, {
   245  		in:       "SELECT * FROM tbl WHERE not id >= 33",
   246  		expected: "SELECT * FROM tbl WHERE id < 33",
   247  	}, {
   248  		in:       "SELECT * FROM tbl WHERE not id != 33",
   249  		expected: "SELECT * FROM tbl WHERE id = 33",
   250  	}, {
   251  		in:       "SELECT * FROM tbl WHERE not id in (1,2,3)",
   252  		expected: "SELECT * FROM tbl WHERE id not in (1,2,3)",
   253  	}, {
   254  		in:       "SELECT * FROM tbl WHERE not id not in (1,2,3)",
   255  		expected: "SELECT * FROM tbl WHERE id in (1,2,3)",
   256  	}, {
   257  		in:       "SELECT * FROM tbl WHERE not id not in (1,2,3)",
   258  		expected: "SELECT * FROM tbl WHERE id in (1,2,3)",
   259  	}, {
   260  		in:       "SELECT * FROM tbl WHERE not id like '%foobar'",
   261  		expected: "SELECT * FROM tbl WHERE id not like '%foobar'",
   262  	}, {
   263  		in:       "SELECT * FROM tbl WHERE not id not like '%foobar'",
   264  		expected: "SELECT * FROM tbl WHERE id like '%foobar'",
   265  	}, {
   266  		in:       "SELECT * FROM tbl WHERE not id regexp '%foobar'",
   267  		expected: "SELECT * FROM tbl WHERE id not regexp '%foobar'",
   268  	}, {
   269  		in:       "SELECT * FROM tbl WHERE not id not regexp '%foobar'",
   270  		expected: "select * from tbl where id regexp '%foobar'",
   271  	}, {
   272  		in:       "SELECT * FROM tbl WHERE exists(select col1, col2 from other_table where foo > bar)",
   273  		expected: "SELECT * FROM tbl WHERE exists(select 1 from other_table where foo > bar limit 1)",
   274  	}, {
   275  		in:       "SELECT * FROM tbl WHERE exists(select col1, col2 from other_table where foo > bar limit 100 offset 34)",
   276  		expected: "SELECT * FROM tbl WHERE exists(select 1 from other_table where foo > bar limit 1 offset 34)",
   277  	}, {
   278  		in:       "SELECT * FROM tbl WHERE exists(select col1, col2, count(*) from other_table where foo > bar group by col1, col2)",
   279  		expected: "SELECT * FROM tbl WHERE exists(select 1 from other_table where foo > bar limit 1)",
   280  	}, {
   281  		in:       "SELECT * FROM tbl WHERE exists(select col1, col2 from other_table where foo > bar group by col1, col2)",
   282  		expected: "SELECT * FROM tbl WHERE exists(select 1 from other_table where foo > bar limit 1)",
   283  	}, {
   284  		in:       "SELECT * FROM tbl WHERE exists(select count(*) from other_table where foo > bar)",
   285  		expected: "SELECT * FROM tbl WHERE true",
   286  	}, {
   287  		in:       "SELECT * FROM tbl WHERE exists(select col1, col2, count(*) from other_table where foo > bar group by col1, col2 having count(*) > 3)",
   288  		expected: "SELECT * FROM tbl WHERE exists(select col1, col2, count(*) from other_table where foo > bar group by col1, col2 having count(*) > 3 limit 1)",
   289  	}, {
   290  		in:       "SELECT id, name, salary FROM user_details",
   291  		expected: "SELECT id, name, salary FROM (select user.id, user.name, user_extra.salary from user join user_extra where user.id = user_extra.user_id) as user_details",
   292  	}, {
   293  		in:                          "SHOW VARIABLES",
   294  		expected:                    "SHOW VARIABLES",
   295  		autocommit:                  true,
   296  		clientFoundRows:             true,
   297  		skipQueryPlanCache:          true,
   298  		sqlSelectLimit:              true,
   299  		transactionMode:             true,
   300  		workload:                    true,
   301  		version:                     true,
   302  		versionComment:              true,
   303  		ddlStrategy:                 true,
   304  		sessionUUID:                 true,
   305  		sessionEnableSystemSettings: true,
   306  		rawGTID:                     true,
   307  		rawTimeout:                  true,
   308  		sessTrackGTID:               true,
   309  		socket:                      true,
   310  		queryTimeout:                true,
   311  	}, {
   312  		in:                          "SHOW GLOBAL VARIABLES",
   313  		expected:                    "SHOW GLOBAL VARIABLES",
   314  		autocommit:                  true,
   315  		clientFoundRows:             true,
   316  		skipQueryPlanCache:          true,
   317  		sqlSelectLimit:              true,
   318  		transactionMode:             true,
   319  		workload:                    true,
   320  		version:                     true,
   321  		versionComment:              true,
   322  		ddlStrategy:                 true,
   323  		sessionUUID:                 true,
   324  		sessionEnableSystemSettings: true,
   325  		rawGTID:                     true,
   326  		rawTimeout:                  true,
   327  		sessTrackGTID:               true,
   328  		socket:                      true,
   329  		queryTimeout:                true,
   330  	}}
   331  
   332  	for _, tc := range tests {
   333  		in.Run(tc.in, func(t *testing.T) {
   334  			require := require.New(t)
   335  			stmt, err := Parse(tc.in)
   336  			require.NoError(err)
   337  
   338  			result, err := RewriteAST(
   339  				stmt,
   340  				"ks", // passing `ks` just to test that no rewriting happens as it is not system schema
   341  				SQLSelectLimitUnset,
   342  				"",
   343  				nil,
   344  				&fakeViews{},
   345  			)
   346  			require.NoError(err)
   347  
   348  			expected, err := Parse(tc.expected)
   349  			require.NoError(err, "test expectation does not parse [%s]", tc.expected)
   350  
   351  			s := String(expected)
   352  			assert := assert.New(t)
   353  			assert.Equal(s, String(result.AST))
   354  			assert.Equal(tc.liid, result.NeedsFuncResult(LastInsertIDName), "should need last insert id")
   355  			assert.Equal(tc.db, result.NeedsFuncResult(DBVarName), "should need database name")
   356  			assert.Equal(tc.foundRows, result.NeedsFuncResult(FoundRowsName), "should need found rows")
   357  			assert.Equal(tc.rowCount, result.NeedsFuncResult(RowCountName), "should need row count")
   358  			assert.Equal(tc.udv, len(result.NeedUserDefinedVariables), "count of user defined variables")
   359  			assert.Equal(tc.autocommit, result.NeedsSysVar(sysvars.Autocommit.Name), "should need :__vtautocommit")
   360  			assert.Equal(tc.clientFoundRows, result.NeedsSysVar(sysvars.ClientFoundRows.Name), "should need :__vtclientFoundRows")
   361  			assert.Equal(tc.skipQueryPlanCache, result.NeedsSysVar(sysvars.SkipQueryPlanCache.Name), "should need :__vtskipQueryPlanCache")
   362  			assert.Equal(tc.sqlSelectLimit, result.NeedsSysVar(sysvars.SQLSelectLimit.Name), "should need :__vtsqlSelectLimit")
   363  			assert.Equal(tc.transactionMode, result.NeedsSysVar(sysvars.TransactionMode.Name), "should need :__vttransactionMode")
   364  			assert.Equal(tc.workload, result.NeedsSysVar(sysvars.Workload.Name), "should need :__vtworkload")
   365  			assert.Equal(tc.queryTimeout, result.NeedsSysVar(sysvars.QueryTimeout.Name), "should need :__vtquery_timeout")
   366  			assert.Equal(tc.ddlStrategy, result.NeedsSysVar(sysvars.DDLStrategy.Name), "should need ddlStrategy")
   367  			assert.Equal(tc.sessionUUID, result.NeedsSysVar(sysvars.SessionUUID.Name), "should need sessionUUID")
   368  			assert.Equal(tc.sessionEnableSystemSettings, result.NeedsSysVar(sysvars.SessionEnableSystemSettings.Name), "should need sessionEnableSystemSettings")
   369  			assert.Equal(tc.rawGTID, result.NeedsSysVar(sysvars.ReadAfterWriteGTID.Name), "should need rawGTID")
   370  			assert.Equal(tc.rawTimeout, result.NeedsSysVar(sysvars.ReadAfterWriteTimeOut.Name), "should need rawTimeout")
   371  			assert.Equal(tc.sessTrackGTID, result.NeedsSysVar(sysvars.SessionTrackGTIDs.Name), "should need sessTrackGTID")
   372  			assert.Equal(tc.version, result.NeedsSysVar(sysvars.Version.Name), "should need Vitess version")
   373  			assert.Equal(tc.versionComment, result.NeedsSysVar(sysvars.VersionComment.Name), "should need Vitess version")
   374  			assert.Equal(tc.socket, result.NeedsSysVar(sysvars.Socket.Name), "should need :__vtsocket")
   375  		})
   376  	}
   377  }
   378  
   379  type fakeViews struct{}
   380  
   381  func (*fakeViews) FindView(name TableName) SelectStatement {
   382  	if name.Name.String() != "user_details" {
   383  		return nil
   384  	}
   385  	statement, err := Parse("select user.id, user.name, user_extra.salary from user join user_extra where user.id = user_extra.user_id")
   386  	if err != nil {
   387  		return nil
   388  	}
   389  	return statement.(SelectStatement)
   390  }
   391  
   392  func TestRewritesWithSetVarComment(in *testing.T) {
   393  	tests := []testCaseSetVar{{
   394  		in:            "select 1",
   395  		expected:      "select 1",
   396  		setVarComment: "",
   397  	}, {
   398  		in:            "select 1",
   399  		expected:      "select /*+ AA(a) */ 1",
   400  		setVarComment: "AA(a)",
   401  	}, {
   402  		in:            "insert /* toto */ into t(id) values(1)",
   403  		expected:      "insert /*+ AA(a) */ /* toto */ into t(id) values(1)",
   404  		setVarComment: "AA(a)",
   405  	}, {
   406  		in:            "select  /* toto */ * from t union select * from s",
   407  		expected:      "select /*+ AA(a) */ /* toto */ * from t union select /*+ AA(a) */ * from s",
   408  		setVarComment: "AA(a)",
   409  	}, {
   410  		in:            "vstream /* toto */ * from t1",
   411  		expected:      "vstream /*+ AA(a) */ /* toto */ * from t1",
   412  		setVarComment: "AA(a)",
   413  	}, {
   414  		in:            "stream /* toto */ t from t1",
   415  		expected:      "stream /*+ AA(a) */ /* toto */ t from t1",
   416  		setVarComment: "AA(a)",
   417  	}, {
   418  		in:            "update /* toto */ t set id = 1",
   419  		expected:      "update /*+ AA(a) */ /* toto */ t set id = 1",
   420  		setVarComment: "AA(a)",
   421  	}, {
   422  		in:            "delete /* toto */ from t",
   423  		expected:      "delete /*+ AA(a) */ /* toto */ from t",
   424  		setVarComment: "AA(a)",
   425  	}}
   426  
   427  	for _, tc := range tests {
   428  		in.Run(tc.in, func(t *testing.T) {
   429  			require := require.New(t)
   430  			stmt, err := Parse(tc.in)
   431  			require.NoError(err)
   432  
   433  			result, err := RewriteAST(stmt, "ks", SQLSelectLimitUnset, tc.setVarComment, nil, &fakeViews{})
   434  			require.NoError(err)
   435  
   436  			expected, err := Parse(tc.expected)
   437  			require.NoError(err, "test expectation does not parse [%s]", tc.expected)
   438  
   439  			assert.Equal(t, String(expected), String(result.AST))
   440  		})
   441  	}
   442  }
   443  
   444  func TestRewritesSysVar(in *testing.T) {
   445  	tests := []testCaseSysVar{{
   446  		in:       "select @x = @@sql_mode",
   447  		expected: "select :__vtudvx = @@sql_mode as `@x = @@sql_mode` from dual",
   448  	}, {
   449  		in:       "select @x = @@sql_mode",
   450  		expected: "select :__vtudvx = :__vtsql_mode as `@x = @@sql_mode` from dual",
   451  		sysVar:   map[string]string{"sql_mode": "' '"},
   452  	}, {
   453  		in:       "SELECT @@tx_isolation",
   454  		expected: "select @@tx_isolation from dual",
   455  	}, {
   456  		in:       "SELECT @@transaction_isolation",
   457  		expected: "select @@transaction_isolation from dual",
   458  	}, {
   459  		in:       "SELECT @@session.transaction_isolation",
   460  		expected: "select @@session.transaction_isolation from dual",
   461  	}, {
   462  		in:       "SELECT @@tx_isolation",
   463  		sysVar:   map[string]string{"tx_isolation": "'READ-COMMITTED'"},
   464  		expected: "select :__vttx_isolation as `@@tx_isolation` from dual",
   465  	}, {
   466  		in:       "SELECT @@transaction_isolation",
   467  		sysVar:   map[string]string{"transaction_isolation": "'READ-COMMITTED'"},
   468  		expected: "select :__vttransaction_isolation as `@@transaction_isolation` from dual",
   469  	}, {
   470  		in:       "SELECT @@session.transaction_isolation",
   471  		sysVar:   map[string]string{"transaction_isolation": "'READ-COMMITTED'"},
   472  		expected: "select :__vttransaction_isolation as `@@session.transaction_isolation` from dual",
   473  	}}
   474  
   475  	for _, tc := range tests {
   476  		in.Run(tc.in, func(t *testing.T) {
   477  			require := require.New(t)
   478  			stmt, err := Parse(tc.in)
   479  			require.NoError(err)
   480  
   481  			result, err := RewriteAST(stmt, "ks", SQLSelectLimitUnset, "", tc.sysVar, &fakeViews{})
   482  			require.NoError(err)
   483  
   484  			expected, err := Parse(tc.expected)
   485  			require.NoError(err, "test expectation does not parse [%s]", tc.expected)
   486  
   487  			assert.Equal(t, String(expected), String(result.AST))
   488  		})
   489  	}
   490  }
   491  
   492  func TestRewritesWithDefaultKeyspace(in *testing.T) {
   493  	tests := []myTestCase{{
   494  		in:       "SELECT 1 from x.test",
   495  		expected: "SELECT 1 from x.test", // no change
   496  	}, {
   497  		in:       "SELECT x.col as c from x.test",
   498  		expected: "SELECT x.col as c from x.test", // no change
   499  	}, {
   500  		in:       "SELECT 1 from test",
   501  		expected: "SELECT 1 from sys.test",
   502  	}, {
   503  		in:       "SELECT 1 from test as t",
   504  		expected: "SELECT 1 from sys.test as t",
   505  	}, {
   506  		in:       "SELECT 1 from `test 24` as t",
   507  		expected: "SELECT 1 from sys.`test 24` as t",
   508  	}, {
   509  		in:       "SELECT 1, (select 1 from test) from x.y",
   510  		expected: "SELECT 1, (select 1 from sys.test) from x.y",
   511  	}, {
   512  		in:       "SELECT 1 from (select 2 from test) t",
   513  		expected: "SELECT 1 from (select 2 from sys.test) t",
   514  	}, {
   515  		in:       "SELECT 1 from test where exists (select 2 from test)",
   516  		expected: "SELECT 1 from sys.test where exists (select 1 from sys.test limit 1)",
   517  	}, {
   518  		in:       "SELECT 1 from dual",
   519  		expected: "SELECT 1 from dual",
   520  	}, {
   521  		in:       "SELECT (select 2 from dual) from DUAL",
   522  		expected: "SELECT 2 as `(select 2 from dual)` from DUAL",
   523  	}}
   524  
   525  	for _, tc := range tests {
   526  		in.Run(tc.in, func(t *testing.T) {
   527  			require := require.New(t)
   528  			stmt, err := Parse(tc.in)
   529  			require.NoError(err)
   530  
   531  			result, err := RewriteAST(stmt, "sys", SQLSelectLimitUnset, "", nil, &fakeViews{})
   532  			require.NoError(err)
   533  
   534  			expected, err := Parse(tc.expected)
   535  			require.NoError(err, "test expectation does not parse [%s]", tc.expected)
   536  
   537  			assert.Equal(t, String(expected), String(result.AST))
   538  		})
   539  	}
   540  }
   541  
   542  func TestReservedVars(t *testing.T) {
   543  	for _, prefix := range []string{"vtg", "bv"} {
   544  		t.Run("prefix_"+prefix, func(t *testing.T) {
   545  			reserved := NewReservedVars(prefix, make(BindVars))
   546  			for i := 1; i < 1000; i++ {
   547  				require.Equal(t, fmt.Sprintf("%s%d", prefix, i), reserved.nextUnusedVar())
   548  			}
   549  		})
   550  	}
   551  }