vitess.io/vitess@v0.16.2/go/vt/vtgate/semantics/analyzer_test.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package semantics
    18  
    19  import (
    20  	"fmt"
    21  	"testing"
    22  
    23  	"vitess.io/vitess/go/vt/vtgate/engine"
    24  
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  
    28  	querypb "vitess.io/vitess/go/vt/proto/query"
    29  	"vitess.io/vitess/go/vt/sqlparser"
    30  	"vitess.io/vitess/go/vt/vtgate/vindexes"
    31  )
    32  
    33  var T0 TableSet
    34  
    35  var (
    36  	// Just here to make outputs more readable
    37  	None = EmptyTableSet()
    38  	T1   = SingleTableSet(0)
    39  	T2   = SingleTableSet(1)
    40  	T3   = SingleTableSet(2)
    41  	T4   = SingleTableSet(3)
    42  	T5   = SingleTableSet(4)
    43  )
    44  
    45  func extract(in *sqlparser.Select, idx int) sqlparser.Expr {
    46  	return in.SelectExprs[idx].(*sqlparser.AliasedExpr).Expr
    47  }
    48  
    49  func TestBindingSingleTablePositive(t *testing.T) {
    50  	queries := []string{
    51  		"select col from tabl",
    52  		"select uid from t2",
    53  		"select tabl.col from tabl",
    54  		"select d.tabl.col from tabl",
    55  		"select col from d.tabl",
    56  		"select tabl.col from d.tabl",
    57  		"select d.tabl.col from d.tabl",
    58  		"select col+col from tabl",
    59  		"select max(col1+col2) from d.tabl",
    60  		"select max(id) from t1",
    61  	}
    62  	for _, query := range queries {
    63  		t.Run(query, func(t *testing.T) {
    64  			stmt, semTable := parseAndAnalyze(t, query, "d")
    65  			sel, _ := stmt.(*sqlparser.Select)
    66  			t1 := sel.From[0].(*sqlparser.AliasedTableExpr)
    67  			ts := semTable.TableSetFor(t1)
    68  			assert.Equal(t, SingleTableSet(0), ts)
    69  
    70  			recursiveDeps := semTable.RecursiveDeps(extract(sel, 0))
    71  			assert.Equal(t, T1, recursiveDeps, query)
    72  			assert.Equal(t, T1, semTable.DirectDeps(extract(sel, 0)), query)
    73  			assert.Equal(t, 1, recursiveDeps.NumberOfTables(), "number of tables is wrong")
    74  		})
    75  	}
    76  }
    77  
    78  func TestInformationSchemaColumnInfo(t *testing.T) {
    79  	stmt, semTable := parseAndAnalyze(t, "select table_comment, file_name from information_schema.`TABLES`, information_schema.`FILES`", "d")
    80  
    81  	sel, _ := stmt.(*sqlparser.Select)
    82  	tables := SingleTableSet(0)
    83  	files := SingleTableSet(1)
    84  
    85  	assert.Equal(t, tables, semTable.RecursiveDeps(extract(sel, 0)))
    86  	assert.Equal(t, files, semTable.DirectDeps(extract(sel, 1)))
    87  }
    88  
    89  func TestBindingSingleAliasedTablePositive(t *testing.T) {
    90  	queries := []string{
    91  		"select col from tabl as X",
    92  		"select tabl.col from X as tabl",
    93  		"select col from d.X as tabl",
    94  		"select tabl.col from d.X as tabl",
    95  		"select col+col from tabl as X",
    96  		"select max(tabl.col1 + tabl.col2) from d.X as tabl",
    97  		"select max(t.id) from t1 as t",
    98  	}
    99  	for _, query := range queries {
   100  		t.Run(query, func(t *testing.T) {
   101  			stmt, semTable := parseAndAnalyze(t, query, "")
   102  			sel, _ := stmt.(*sqlparser.Select)
   103  			t1 := sel.From[0].(*sqlparser.AliasedTableExpr)
   104  			ts := semTable.TableSetFor(t1)
   105  			assert.Equal(t, SingleTableSet(0), ts)
   106  
   107  			recursiveDeps := semTable.RecursiveDeps(extract(sel, 0))
   108  			require.Equal(t, T1, recursiveDeps, query)
   109  			assert.Equal(t, 1, recursiveDeps.NumberOfTables(), "number of tables is wrong")
   110  		})
   111  	}
   112  }
   113  
   114  func TestBindingSingleTableNegative(t *testing.T) {
   115  	queries := []string{
   116  		"select foo.col from tabl",
   117  		"select ks.tabl.col from tabl",
   118  		"select ks.tabl.col from d.tabl",
   119  		"select d.tabl.col from ks.tabl",
   120  		"select foo.col from d.tabl",
   121  		"select tabl.col from d.tabl as t",
   122  	}
   123  	for _, query := range queries {
   124  		t.Run(query, func(t *testing.T) {
   125  			parse, err := sqlparser.Parse(query)
   126  			require.NoError(t, err)
   127  			st, err := Analyze(parse.(sqlparser.SelectStatement), "d", &FakeSI{})
   128  			require.NoError(t, err)
   129  			require.ErrorContains(t, st.NotUnshardedErr, "symbol")
   130  			require.ErrorContains(t, st.NotUnshardedErr, "not found")
   131  		})
   132  	}
   133  }
   134  
   135  func TestBindingSingleAliasedTableNegative(t *testing.T) {
   136  	queries := []string{
   137  		"select tabl.col from tabl as X",
   138  		"select d.X.col from d.X as tabl",
   139  		"select d.tabl.col from X as tabl",
   140  		"select d.tabl.col from ks.X as tabl",
   141  		"select d.tabl.col from d.X as tabl",
   142  	}
   143  	for _, query := range queries {
   144  		t.Run(query, func(t *testing.T) {
   145  			parse, err := sqlparser.Parse(query)
   146  			require.NoError(t, err)
   147  			st, err := Analyze(parse.(sqlparser.SelectStatement), "", &FakeSI{
   148  				Tables: map[string]*vindexes.Table{
   149  					"t": {Name: sqlparser.NewIdentifierCS("t")},
   150  				},
   151  			})
   152  			require.NoError(t, err)
   153  			require.Error(t, st.NotUnshardedErr)
   154  		})
   155  	}
   156  }
   157  
   158  func TestBindingMultiTablePositive(t *testing.T) {
   159  	type testCase struct {
   160  		query          string
   161  		deps           TableSet
   162  		numberOfTables int
   163  	}
   164  	queries := []testCase{{
   165  		query:          "select t.col from t, s",
   166  		deps:           T1,
   167  		numberOfTables: 1,
   168  	}, {
   169  		query:          "select s.col from t join s",
   170  		deps:           T2,
   171  		numberOfTables: 1,
   172  	}, {
   173  		query:          "select max(t.col+s.col) from t, s",
   174  		deps:           MergeTableSets(T1, T2),
   175  		numberOfTables: 2,
   176  	}, {
   177  		query:          "select max(t.col+s.col) from t join s",
   178  		deps:           MergeTableSets(T1, T2),
   179  		numberOfTables: 2,
   180  	}, {
   181  		query:          "select case t.col when s.col then r.col else u.col end from t, s, r, w, u",
   182  		deps:           MergeTableSets(T1, T2, T3, T5),
   183  		numberOfTables: 4,
   184  		// }, {
   185  		// TODO: move to subquery
   186  		// make sure that we don't let sub-query dependencies leak out by mistake
   187  		// query: "select t.col + (select 42 from s) from t",
   188  		// deps:  T1,
   189  		// }, {
   190  		// 	query: "select (select 42 from s where r.id = s.id) from r",
   191  		// 	deps:  T1 | T2,
   192  	}, {
   193  		query:          "select u1.a + u2.a from u1, u2",
   194  		deps:           MergeTableSets(T1, T2),
   195  		numberOfTables: 2,
   196  	}}
   197  	for _, query := range queries {
   198  		t.Run(query.query, func(t *testing.T) {
   199  			stmt, semTable := parseAndAnalyze(t, query.query, "user")
   200  			sel, _ := stmt.(*sqlparser.Select)
   201  			recursiveDeps := semTable.RecursiveDeps(extract(sel, 0))
   202  			assert.Equal(t, query.deps, recursiveDeps, query.query)
   203  			assert.Equal(t, query.numberOfTables, recursiveDeps.NumberOfTables(), "number of tables is wrong")
   204  		})
   205  	}
   206  }
   207  
   208  func TestBindingMultiAliasedTablePositive(t *testing.T) {
   209  	type testCase struct {
   210  		query          string
   211  		deps           TableSet
   212  		numberOfTables int
   213  	}
   214  	queries := []testCase{{
   215  		query:          "select X.col from t as X, s as S",
   216  		deps:           T1,
   217  		numberOfTables: 1,
   218  	}, {
   219  		query:          "select X.col+S.col from t as X, s as S",
   220  		deps:           MergeTableSets(T1, T2),
   221  		numberOfTables: 2,
   222  	}, {
   223  		query:          "select max(X.col+S.col) from t as X, s as S",
   224  		deps:           MergeTableSets(T1, T2),
   225  		numberOfTables: 2,
   226  	}, {
   227  		query:          "select max(X.col+s.col) from t as X, s",
   228  		deps:           MergeTableSets(T1, T2),
   229  		numberOfTables: 2,
   230  	}}
   231  	for _, query := range queries {
   232  		t.Run(query.query, func(t *testing.T) {
   233  			stmt, semTable := parseAndAnalyze(t, query.query, "user")
   234  			sel, _ := stmt.(*sqlparser.Select)
   235  			recursiveDeps := semTable.RecursiveDeps(extract(sel, 0))
   236  			assert.Equal(t, query.deps, recursiveDeps, query.query)
   237  			assert.Equal(t, query.numberOfTables, recursiveDeps.NumberOfTables(), "number of tables is wrong")
   238  		})
   239  	}
   240  }
   241  
   242  func TestBindingMultiTableNegative(t *testing.T) {
   243  	queries := []string{
   244  		"select 1 from d.tabl, d.tabl",
   245  		"select 1 from d.tabl, tabl",
   246  		"select t.col from k.t, t",
   247  		"select b.t.col from b.t, t",
   248  	}
   249  	for _, query := range queries {
   250  		t.Run(query, func(t *testing.T) {
   251  			parse, err := sqlparser.Parse(query)
   252  			require.NoError(t, err)
   253  			_, err = Analyze(parse.(sqlparser.SelectStatement), "d", &FakeSI{
   254  				Tables: map[string]*vindexes.Table{
   255  					"tabl": {Name: sqlparser.NewIdentifierCS("tabl")},
   256  					"foo":  {Name: sqlparser.NewIdentifierCS("foo")},
   257  				},
   258  			})
   259  			require.Error(t, err)
   260  		})
   261  	}
   262  }
   263  
   264  func TestBindingMultiAliasedTableNegative(t *testing.T) {
   265  	queries := []string{
   266  		"select 1 from d.tabl as tabl, d.tabl",
   267  		"select 1 from d.tabl as tabl, tabl",
   268  		"select 1 from d.tabl as a, tabl as a",
   269  		"select 1 from user join user_extra user",
   270  		"select t.col from k.t as t, t",
   271  		"select b.t.col from b.t as t, t",
   272  	}
   273  	for _, query := range queries {
   274  		t.Run(query, func(t *testing.T) {
   275  			parse, err := sqlparser.Parse(query)
   276  			require.NoError(t, err)
   277  			_, err = Analyze(parse.(sqlparser.SelectStatement), "d", &FakeSI{
   278  				Tables: map[string]*vindexes.Table{
   279  					"tabl": {Name: sqlparser.NewIdentifierCS("tabl")},
   280  					"foo":  {Name: sqlparser.NewIdentifierCS("foo")},
   281  				},
   282  			})
   283  			require.Error(t, err)
   284  		})
   285  	}
   286  }
   287  
   288  func TestNotUniqueTableName(t *testing.T) {
   289  	queries := []string{
   290  		"select * from t, t",
   291  		"select * from t, (select 1 from x) as t",
   292  		"select * from t join t",
   293  		"select * from t join (select 1 from x) as t",
   294  	}
   295  
   296  	for _, query := range queries {
   297  		t.Run(query, func(t *testing.T) {
   298  			parse, _ := sqlparser.Parse(query)
   299  			_, err := Analyze(parse.(sqlparser.SelectStatement), "test", &FakeSI{})
   300  			require.Error(t, err)
   301  			require.Contains(t, err.Error(), "VT03013: not unique table/alias")
   302  		})
   303  	}
   304  }
   305  
   306  func TestMissingTable(t *testing.T) {
   307  	queries := []string{
   308  		"select t.col from a",
   309  	}
   310  
   311  	for _, query := range queries {
   312  		t.Run(query, func(t *testing.T) {
   313  			parse, _ := sqlparser.Parse(query)
   314  			st, err := Analyze(parse.(sqlparser.SelectStatement), "", &FakeSI{})
   315  			require.NoError(t, err)
   316  			require.ErrorContains(t, st.NotUnshardedErr, "symbol t.col not found")
   317  		})
   318  	}
   319  }
   320  
   321  func TestUnknownColumnMap2(t *testing.T) {
   322  	varchar := querypb.Type_VARCHAR
   323  	integer := querypb.Type_INT32
   324  
   325  	authoritativeTblA := vindexes.Table{
   326  		Name: sqlparser.NewIdentifierCS("a"),
   327  		Columns: []vindexes.Column{{
   328  			Name: sqlparser.NewIdentifierCI("col2"),
   329  			Type: varchar,
   330  		}},
   331  		ColumnListAuthoritative: true,
   332  	}
   333  	authoritativeTblB := vindexes.Table{
   334  		Name: sqlparser.NewIdentifierCS("b"),
   335  		Columns: []vindexes.Column{{
   336  			Name: sqlparser.NewIdentifierCI("col"),
   337  			Type: varchar,
   338  		}},
   339  		ColumnListAuthoritative: true,
   340  	}
   341  	nonAuthoritativeTblA := authoritativeTblA
   342  	nonAuthoritativeTblA.ColumnListAuthoritative = false
   343  	nonAuthoritativeTblB := authoritativeTblB
   344  	nonAuthoritativeTblB.ColumnListAuthoritative = false
   345  	authoritativeTblAWithConflict := vindexes.Table{
   346  		Name: sqlparser.NewIdentifierCS("a"),
   347  		Columns: []vindexes.Column{{
   348  			Name: sqlparser.NewIdentifierCI("col"),
   349  			Type: integer,
   350  		}},
   351  		ColumnListAuthoritative: true,
   352  	}
   353  	authoritativeTblBWithInt := vindexes.Table{
   354  		Name: sqlparser.NewIdentifierCS("b"),
   355  		Columns: []vindexes.Column{{
   356  			Name: sqlparser.NewIdentifierCI("col"),
   357  			Type: integer,
   358  		}},
   359  		ColumnListAuthoritative: true,
   360  	}
   361  
   362  	tests := []struct {
   363  		name   string
   364  		schema map[string]*vindexes.Table
   365  		err    bool
   366  		typ    *querypb.Type
   367  	}{{
   368  		name:   "no info about tables",
   369  		schema: map[string]*vindexes.Table{"a": {}, "b": {}},
   370  		err:    true,
   371  	}, {
   372  		name:   "non authoritative columns",
   373  		schema: map[string]*vindexes.Table{"a": &nonAuthoritativeTblA, "b": &nonAuthoritativeTblA},
   374  		err:    true,
   375  	}, {
   376  		name:   "non authoritative columns - one authoritative and one not",
   377  		schema: map[string]*vindexes.Table{"a": &nonAuthoritativeTblA, "b": &authoritativeTblB},
   378  		err:    false,
   379  		typ:    &varchar,
   380  	}, {
   381  		name:   "non authoritative columns - one authoritative and one not",
   382  		schema: map[string]*vindexes.Table{"a": &authoritativeTblA, "b": &nonAuthoritativeTblB},
   383  		err:    false,
   384  		typ:    &varchar,
   385  	}, {
   386  		name:   "authoritative columns",
   387  		schema: map[string]*vindexes.Table{"a": &authoritativeTblA, "b": &authoritativeTblB},
   388  		err:    false,
   389  		typ:    &varchar,
   390  	}, {
   391  		name:   "authoritative columns",
   392  		schema: map[string]*vindexes.Table{"a": &authoritativeTblA, "b": &authoritativeTblBWithInt},
   393  		err:    false,
   394  		typ:    &integer,
   395  	}, {
   396  		name:   "authoritative columns with overlap",
   397  		schema: map[string]*vindexes.Table{"a": &authoritativeTblAWithConflict, "b": &authoritativeTblB},
   398  		err:    true,
   399  	}}
   400  
   401  	queries := []string{"select col from a, b", "select col from a as user, b as extra"}
   402  	for _, query := range queries {
   403  		t.Run(query, func(t *testing.T) {
   404  			parse, _ := sqlparser.Parse(query)
   405  			expr := extract(parse.(*sqlparser.Select), 0)
   406  
   407  			for _, test := range tests {
   408  				t.Run(test.name, func(t *testing.T) {
   409  					si := &FakeSI{Tables: test.schema}
   410  					tbl, err := Analyze(parse.(sqlparser.SelectStatement), "", si)
   411  					if test.err {
   412  						require.True(t, err != nil || tbl.NotSingleRouteErr != nil)
   413  					} else {
   414  						require.NoError(t, err)
   415  						require.NoError(t, tbl.NotSingleRouteErr)
   416  						typ := tbl.TypeFor(expr)
   417  						assert.Equal(t, test.typ, typ)
   418  					}
   419  				})
   420  			}
   421  		})
   422  	}
   423  }
   424  
   425  func TestUnknownPredicate(t *testing.T) {
   426  	query := "select 1 from a, b where col = 1"
   427  	authoritativeTblA := &vindexes.Table{
   428  		Name: sqlparser.NewIdentifierCS("a"),
   429  	}
   430  	authoritativeTblB := &vindexes.Table{
   431  		Name: sqlparser.NewIdentifierCS("b"),
   432  	}
   433  
   434  	parse, _ := sqlparser.Parse(query)
   435  
   436  	tests := []struct {
   437  		name   string
   438  		schema map[string]*vindexes.Table
   439  		err    bool
   440  	}{
   441  		{
   442  			name:   "no info about tables",
   443  			schema: map[string]*vindexes.Table{"a": authoritativeTblA, "b": authoritativeTblB},
   444  			err:    false,
   445  		},
   446  	}
   447  	for _, test := range tests {
   448  		t.Run(test.name, func(t *testing.T) {
   449  			si := &FakeSI{Tables: test.schema}
   450  			_, err := Analyze(parse.(sqlparser.SelectStatement), "", si)
   451  			if test.err {
   452  				require.Error(t, err)
   453  			} else {
   454  				require.NoError(t, err)
   455  			}
   456  		})
   457  	}
   458  }
   459  
   460  func TestScoping(t *testing.T) {
   461  	queries := []struct {
   462  		query        string
   463  		errorMessage string
   464  	}{
   465  		{
   466  			query:        "select 1 from u1, u2 left join u3 on u1.a = u2.a",
   467  			errorMessage: "symbol u1.a not found",
   468  		},
   469  	}
   470  	for _, query := range queries {
   471  		t.Run(query.query, func(t *testing.T) {
   472  			parse, err := sqlparser.Parse(query.query)
   473  			require.NoError(t, err)
   474  			st, err := Analyze(parse.(sqlparser.SelectStatement), "user", &FakeSI{
   475  				Tables: map[string]*vindexes.Table{
   476  					"t": {Name: sqlparser.NewIdentifierCS("t")},
   477  				},
   478  			})
   479  			require.NoError(t, err)
   480  			require.EqualError(t, st.NotUnshardedErr, query.errorMessage)
   481  		})
   482  	}
   483  }
   484  
   485  func TestScopeForSubqueries(t *testing.T) {
   486  	tcases := []struct {
   487  		sql  string
   488  		deps TableSet
   489  	}{
   490  		{
   491  			sql:  `select t.col1, (select t.col2 from z as t) from x as t`,
   492  			deps: T2,
   493  		}, {
   494  			sql:  `select t.col1, (select t.col2 from z) from x as t`,
   495  			deps: T1,
   496  		}, {
   497  			sql:  `select t.col1, (select (select z.col2 from y) from z) from x as t`,
   498  			deps: T2,
   499  		}, {
   500  			sql:  `select t.col1, (select (select y.col2 from y) from z) from x as t`,
   501  			deps: None,
   502  		}, {
   503  			sql:  `select t.col1, (select (select (select (select w.col2 from w) from x) from y) from z) from x as t`,
   504  			deps: None,
   505  		}, {
   506  			sql:  `select t.col1, (select id from t) from x as t`,
   507  			deps: T2,
   508  		},
   509  	}
   510  	for _, tc := range tcases {
   511  		t.Run(tc.sql, func(t *testing.T) {
   512  			stmt, semTable := parseAndAnalyze(t, tc.sql, "d")
   513  			sel, _ := stmt.(*sqlparser.Select)
   514  
   515  			// extract the first expression from the subquery (which should be the second expression in the outer query)
   516  			sel2 := sel.SelectExprs[1].(*sqlparser.AliasedExpr).Expr.(*sqlparser.Subquery).Select.(*sqlparser.Select)
   517  			exp := extract(sel2, 0)
   518  			s1 := semTable.RecursiveDeps(exp)
   519  			require.NoError(t, semTable.NotSingleRouteErr)
   520  			// if scoping works as expected, we should be able to see the inner table being used by the inner expression
   521  			assert.Equal(t, tc.deps, s1)
   522  		})
   523  	}
   524  }
   525  
   526  func TestSubqueriesMappingWhereClause(t *testing.T) {
   527  	tcs := []struct {
   528  		sql           string
   529  		opCode        engine.PulloutOpcode
   530  		otherSideName string
   531  	}{
   532  		{
   533  			sql:           "select id from t1 where id in (select uid from t2)",
   534  			opCode:        engine.PulloutIn,
   535  			otherSideName: "id",
   536  		},
   537  		{
   538  			sql:           "select id from t1 where id not in (select uid from t2)",
   539  			opCode:        engine.PulloutNotIn,
   540  			otherSideName: "id",
   541  		},
   542  		{
   543  			sql:           "select id from t where col1 = (select uid from t2 order by uid desc limit 1)",
   544  			opCode:        engine.PulloutValue,
   545  			otherSideName: "col1",
   546  		},
   547  		{
   548  			sql:           "select id from t where exists (select uid from t2 where uid = 42)",
   549  			opCode:        engine.PulloutExists,
   550  			otherSideName: "",
   551  		},
   552  		{
   553  			sql:           "select id from t where col1 >= (select uid from t2 where uid = 42)",
   554  			opCode:        engine.PulloutValue,
   555  			otherSideName: "col1",
   556  		},
   557  	}
   558  
   559  	for i, tc := range tcs {
   560  		t.Run(fmt.Sprintf("%d_%s", i+1, tc.sql), func(t *testing.T) {
   561  			stmt, semTable := parseAndAnalyze(t, tc.sql, "d")
   562  			sel, _ := stmt.(*sqlparser.Select)
   563  
   564  			var subq *sqlparser.Subquery
   565  			switch whereExpr := sel.Where.Expr.(type) {
   566  			case *sqlparser.ComparisonExpr:
   567  				subq = whereExpr.Right.(*sqlparser.Subquery)
   568  			case *sqlparser.ExistsExpr:
   569  				subq = whereExpr.Subquery
   570  			}
   571  
   572  			extractedSubq := semTable.SubqueryRef[subq]
   573  			assert.True(t, sqlparser.Equals.Expr(extractedSubq.Subquery, subq))
   574  			assert.True(t, sqlparser.Equals.Expr(extractedSubq.Original, sel.Where.Expr))
   575  			assert.EqualValues(t, tc.opCode, extractedSubq.OpCode)
   576  			if tc.otherSideName == "" {
   577  				assert.Nil(t, extractedSubq.OtherSide)
   578  			} else {
   579  				assert.True(t, sqlparser.Equals.Expr(extractedSubq.OtherSide, sqlparser.NewColName(tc.otherSideName)))
   580  			}
   581  		})
   582  	}
   583  }
   584  
   585  func TestSubqueriesMappingSelectExprs(t *testing.T) {
   586  	tcs := []struct {
   587  		sql        string
   588  		selExprIdx int
   589  	}{
   590  		{
   591  			sql:        "select (select id from t1)",
   592  			selExprIdx: 0,
   593  		},
   594  		{
   595  			sql:        "select id, (select id from t1) from t1",
   596  			selExprIdx: 1,
   597  		},
   598  	}
   599  
   600  	for i, tc := range tcs {
   601  		t.Run(fmt.Sprintf("%d_%s", i+1, tc.sql), func(t *testing.T) {
   602  			stmt, semTable := parseAndAnalyze(t, tc.sql, "d")
   603  			sel, _ := stmt.(*sqlparser.Select)
   604  
   605  			subq := sel.SelectExprs[tc.selExprIdx].(*sqlparser.AliasedExpr).Expr.(*sqlparser.Subquery)
   606  			extractedSubq := semTable.SubqueryRef[subq]
   607  			assert.True(t, sqlparser.Equals.Expr(extractedSubq.Subquery, subq))
   608  			assert.True(t, sqlparser.Equals.Expr(extractedSubq.Original, subq))
   609  			assert.EqualValues(t, engine.PulloutValue, extractedSubq.OpCode)
   610  		})
   611  	}
   612  }
   613  
   614  func TestSubqueryOrderByBinding(t *testing.T) {
   615  	queries := []struct {
   616  		query    string
   617  		expected TableSet
   618  	}{{
   619  		query:    "select * from user u where exists (select * from user order by col)",
   620  		expected: T2,
   621  	}, {
   622  		query:    "select * from user u where exists (select * from user order by user.col)",
   623  		expected: T2,
   624  	}, {
   625  		query:    "select * from user u where exists (select * from user order by u.col)",
   626  		expected: T1,
   627  	}, {
   628  		query:    "select * from dbName.user as u where exists (select * from dbName.user order by u.col)",
   629  		expected: T1,
   630  	}, {
   631  		query:    "select * from dbName.user where exists (select * from otherDb.user order by dbName.user.col)",
   632  		expected: T1,
   633  	}, {
   634  		query:    "select id from dbName.t1 where exists (select * from dbName.t2 order by dbName.t1.id)",
   635  		expected: T1,
   636  	}}
   637  
   638  	for _, tc := range queries {
   639  		t.Run(tc.query, func(t *testing.T) {
   640  			ast, err := sqlparser.Parse(tc.query)
   641  			require.NoError(t, err)
   642  
   643  			sel := ast.(*sqlparser.Select)
   644  			st, err := Analyze(sel, "dbName", fakeSchemaInfo())
   645  			require.NoError(t, err)
   646  			exists := sel.Where.Expr.(*sqlparser.ExistsExpr)
   647  			expr := exists.Subquery.Select.(*sqlparser.Select).OrderBy[0].Expr
   648  			require.Equal(t, tc.expected, st.DirectDeps(expr))
   649  			require.Equal(t, tc.expected, st.RecursiveDeps(expr))
   650  		})
   651  	}
   652  }
   653  
   654  func TestOrderByBindingTable(t *testing.T) {
   655  	tcases := []struct {
   656  		sql  string
   657  		deps TableSet
   658  	}{{
   659  		"select col from tabl order by col",
   660  		T1,
   661  	}, {
   662  		"select tabl.col from d.tabl order by col",
   663  		T1,
   664  	}, {
   665  		"select d.tabl.col from d.tabl order by col",
   666  		T1,
   667  	}, {
   668  		"select col from tabl order by tabl.col",
   669  		T1,
   670  	}, {
   671  		"select col from tabl order by d.tabl.col",
   672  		T1,
   673  	}, {
   674  		"select col from tabl order by 1",
   675  		T1,
   676  	}, {
   677  		"select col as c from tabl order by c",
   678  		T1,
   679  	}, {
   680  		"select 1 as c from tabl order by c",
   681  		T0,
   682  	}, {
   683  		"select name, name from t1, t2 order by name",
   684  		T2,
   685  	}, {
   686  		"(select id from t1) union (select uid from t2) order by id",
   687  		MergeTableSets(T1, T2),
   688  	}, {
   689  		"select id from t1 union (select uid from t2) order by 1",
   690  		MergeTableSets(T1, T2),
   691  	}, {
   692  		"select id from t1 union select uid from t2 union (select name from t) order by 1",
   693  		MergeTableSets(T1, T2, T3),
   694  	}, {
   695  		"select a.id from t1 as a union (select uid from t2) order by 1",
   696  		MergeTableSets(T1, T2),
   697  	}, {
   698  		"select b.id as a from t1 as b union (select uid as c from t2) order by 1",
   699  		MergeTableSets(T1, T2),
   700  	}, {
   701  		"select a.id from t1 as a union (select uid from t2, t union (select name from t) order by 1) order by 1",
   702  		MergeTableSets(T1, T2, T4),
   703  	}, {
   704  		"select a.id from t1 as a union (select uid from t2, t union (select name from t) order by 1) order by id",
   705  		MergeTableSets(T1, T2, T4),
   706  	}}
   707  	for _, tc := range tcases {
   708  		t.Run(tc.sql, func(t *testing.T) {
   709  			stmt, semTable := parseAndAnalyze(t, tc.sql, "d")
   710  
   711  			var order sqlparser.Expr
   712  			switch stmt := stmt.(type) {
   713  			case *sqlparser.Select:
   714  				order = stmt.OrderBy[0].Expr
   715  			case *sqlparser.Union:
   716  				order = stmt.OrderBy[0].Expr
   717  			default:
   718  				t.Fail()
   719  			}
   720  			d := semTable.RecursiveDeps(order)
   721  			require.Equal(t, tc.deps, d, tc.sql)
   722  		})
   723  	}
   724  }
   725  
   726  func TestGroupByBinding(t *testing.T) {
   727  	tcases := []struct {
   728  		sql  string
   729  		deps TableSet
   730  	}{{
   731  		"select col from tabl group by col",
   732  		T1,
   733  	}, {
   734  		"select col from tabl group by tabl.col",
   735  		T1,
   736  	}, {
   737  		"select col from tabl group by d.tabl.col",
   738  		T1,
   739  	}, {
   740  		"select tabl.col as x from tabl group by x",
   741  		T1,
   742  	}, {
   743  		"select tabl.col as x from tabl group by col",
   744  		T1,
   745  	}, {
   746  		"select d.tabl.col as x from tabl group by x",
   747  		T1,
   748  	}, {
   749  		"select d.tabl.col as x from tabl group by col",
   750  		T1,
   751  	}, {
   752  		"select col from tabl group by 1",
   753  		T1,
   754  	}, {
   755  		"select col as c from tabl group by c",
   756  		T1,
   757  	}, {
   758  		"select 1 as c from tabl group by c",
   759  		T0,
   760  	}, {
   761  		"select t1.id from t1, t2 group by id",
   762  		T1,
   763  	}, {
   764  		"select id from t, t1 group by id",
   765  		T2,
   766  	}, {
   767  		"select id from t, t1 group by id",
   768  		T2,
   769  	}, {
   770  		"select a.id from t as a, t1 group by id",
   771  		T1,
   772  	}, {
   773  		"select a.id from t, t1 as a group by id",
   774  		T2,
   775  	}}
   776  	for _, tc := range tcases {
   777  		t.Run(tc.sql, func(t *testing.T) {
   778  			stmt, semTable := parseAndAnalyze(t, tc.sql, "d")
   779  			sel, _ := stmt.(*sqlparser.Select)
   780  			grp := sel.GroupBy[0]
   781  			d := semTable.RecursiveDeps(grp)
   782  			require.Equal(t, tc.deps, d, tc.sql)
   783  		})
   784  	}
   785  }
   786  
   787  func TestHavingBinding(t *testing.T) {
   788  	tcases := []struct {
   789  		sql  string
   790  		deps TableSet
   791  	}{{
   792  		"select col from tabl having col = 1",
   793  		T1,
   794  	}, {
   795  		"select col from tabl having tabl.col = 1",
   796  		T1,
   797  	}, {
   798  		"select col from tabl having d.tabl.col = 1",
   799  		T1,
   800  	}, {
   801  		"select tabl.col as x from tabl having x = 1",
   802  		T1,
   803  	}, {
   804  		"select tabl.col as x from tabl having col",
   805  		T1,
   806  	}, {
   807  		"select col from tabl having 1 = 1",
   808  		T0,
   809  	}, {
   810  		"select col as c from tabl having c = 1",
   811  		T1,
   812  	}, {
   813  		"select 1 as c from tabl having c = 1",
   814  		T0,
   815  	}, {
   816  		"select t1.id from t1, t2 having id = 1",
   817  		T1,
   818  	}, {
   819  		"select t.id from t, t1 having id = 1",
   820  		T1,
   821  	}, {
   822  		"select t.id, count(*) as a from t, t1 group by t.id having a = 1",
   823  		MergeTableSets(T1, T2),
   824  	}, {
   825  		"select t.id, sum(t2.name) as a from t, t2 group by t.id having a = 1",
   826  		T2,
   827  	}, {
   828  		sql:  "select u2.a, u1.a from u1, u2 having u2.a = 2",
   829  		deps: T2,
   830  	}}
   831  	for _, tc := range tcases {
   832  		t.Run(tc.sql, func(t *testing.T) {
   833  			stmt, semTable := parseAndAnalyze(t, tc.sql, "d")
   834  			sel, _ := stmt.(*sqlparser.Select)
   835  			hvng := sel.Having.Expr
   836  			d := semTable.RecursiveDeps(hvng)
   837  			require.Equal(t, tc.deps, d, tc.sql)
   838  		})
   839  	}
   840  }
   841  
   842  func TestUnionCheckFirstAndLastSelectsDeps(t *testing.T) {
   843  	query := "select col1 from tabl1 union select col2 from tabl2"
   844  
   845  	stmt, semTable := parseAndAnalyze(t, query, "")
   846  	union, _ := stmt.(*sqlparser.Union)
   847  	sel1 := union.Left.(*sqlparser.Select)
   848  	sel2 := union.Right.(*sqlparser.Select)
   849  
   850  	t1 := sel1.From[0].(*sqlparser.AliasedTableExpr)
   851  	t2 := sel2.From[0].(*sqlparser.AliasedTableExpr)
   852  	ts1 := semTable.TableSetFor(t1)
   853  	ts2 := semTable.TableSetFor(t2)
   854  	assert.Equal(t, SingleTableSet(0), ts1)
   855  	assert.Equal(t, SingleTableSet(1), ts2)
   856  
   857  	d1 := semTable.RecursiveDeps(extract(sel1, 0))
   858  	d2 := semTable.RecursiveDeps(extract(sel2, 0))
   859  	assert.Equal(t, T1, d1)
   860  	assert.Equal(t, T2, d2)
   861  }
   862  
   863  func TestUnionOrderByRewrite(t *testing.T) {
   864  	query := "select tabl1.id from tabl1 union select 1 order by 1"
   865  
   866  	stmt, _ := parseAndAnalyze(t, query, "")
   867  	assert.Equal(t, "select tabl1.id from tabl1 union select 1 from dual order by id asc", sqlparser.String(stmt))
   868  }
   869  
   870  func TestInvalidQueries(t *testing.T) {
   871  	tcases := []struct {
   872  		sql        string
   873  		err        string
   874  		shardedErr string
   875  	}{{
   876  		sql: "select t1.id, t1.col1 from t1 union select t2.uid from t2",
   877  		err: "The used SELECT statements have a different number of columns",
   878  	}, {
   879  		sql: "select t1.id from t1 union select t2.uid, t2.price from t2",
   880  		err: "The used SELECT statements have a different number of columns",
   881  	}, {
   882  		sql: "select t1.id from t1 union select t2.uid, t2.price from t2",
   883  		err: "The used SELECT statements have a different number of columns",
   884  	}, {
   885  		sql: "(select 1,2 union select 3,4) union (select 5,6 union select 7)",
   886  		err: "The used SELECT statements have a different number of columns",
   887  	}, {
   888  		sql: "select id from a union select 3 order by a.id",
   889  		err: "Table a from one of the SELECTs cannot be used in global ORDER clause",
   890  	}, {
   891  		sql: "select a.id, b.id from a, b union select 1, 2 order by id",
   892  		err: "Column 'id' in field list is ambiguous",
   893  	}, {
   894  		sql: "select sql_calc_found_rows id from a union select 1 limit 109",
   895  		err: "VT12001: unsupported: SQL_CALC_FOUND_ROWS not supported with union",
   896  	}, {
   897  		sql: "select * from (select sql_calc_found_rows id from a) as t",
   898  		err: "Incorrect usage/placement of 'SQL_CALC_FOUND_ROWS'",
   899  	}, {
   900  		sql: "select (select sql_calc_found_rows id from a) as t",
   901  		err: "Incorrect usage/placement of 'SQL_CALC_FOUND_ROWS'",
   902  	}, {
   903  		sql: "select id from t1 natural join t2",
   904  		err: "VT12001: unsupported: natural join",
   905  	}, {
   906  		sql: "select * from music where user_id IN (select sql_calc_found_rows * from music limit 10)",
   907  		err: "Incorrect usage/placement of 'SQL_CALC_FOUND_ROWS'",
   908  	}, {
   909  		sql: "select is_free_lock('xyz') from user",
   910  		err: "is_free_lock('xyz') allowed only with dual",
   911  	}, {
   912  		sql: "SELECT * FROM JSON_TABLE('[ {\"c1\": null} ]','$[*]' COLUMNS( c1 INT PATH '$.c1' ERROR ON ERROR )) as jt",
   913  		err: "VT12001: unsupported: json_table expressions",
   914  	}, {
   915  		sql:        "select does_not_exist from t1",
   916  		shardedErr: "symbol does_not_exist not found",
   917  	}, {
   918  		sql:        "select t1.does_not_exist from t1, t2",
   919  		shardedErr: "symbol t1.does_not_exist not found",
   920  	}}
   921  	for _, tc := range tcases {
   922  		t.Run(tc.sql, func(t *testing.T) {
   923  			parse, err := sqlparser.Parse(tc.sql)
   924  			require.NoError(t, err)
   925  
   926  			st, err := Analyze(parse.(sqlparser.SelectStatement), "dbName", fakeSchemaInfo())
   927  			if tc.err != "" {
   928  				require.EqualError(t, err, tc.err)
   929  			} else {
   930  				require.NoError(t, err, tc.err)
   931  				require.EqualError(t, st.NotUnshardedErr, tc.shardedErr)
   932  			}
   933  		})
   934  	}
   935  }
   936  
   937  func TestUnionWithOrderBy(t *testing.T) {
   938  	query := "select col1 from tabl1 union (select col2 from tabl2) order by 1"
   939  
   940  	stmt, semTable := parseAndAnalyze(t, query, "")
   941  	union, _ := stmt.(*sqlparser.Union)
   942  	sel1 := sqlparser.GetFirstSelect(union)
   943  	sel2 := sqlparser.GetFirstSelect(union.Right)
   944  
   945  	t1 := sel1.From[0].(*sqlparser.AliasedTableExpr)
   946  	t2 := sel2.From[0].(*sqlparser.AliasedTableExpr)
   947  	ts1 := semTable.TableSetFor(t1)
   948  	ts2 := semTable.TableSetFor(t2)
   949  	assert.Equal(t, SingleTableSet(0), ts1)
   950  	assert.Equal(t, SingleTableSet(1), ts2)
   951  
   952  	d1 := semTable.RecursiveDeps(extract(sel1, 0))
   953  	d2 := semTable.RecursiveDeps(extract(sel2, 0))
   954  	assert.Equal(t, T1, d1)
   955  	assert.Equal(t, T2, d2)
   956  }
   957  
   958  func TestScopingWDerivedTables(t *testing.T) {
   959  	queries := []struct {
   960  		query                string
   961  		errorMessage         string
   962  		recursiveExpectation TableSet
   963  		expectation          TableSet
   964  	}{
   965  		{
   966  			query:                "select id from (select x as id from user) as t",
   967  			recursiveExpectation: T1,
   968  			expectation:          T2,
   969  		}, {
   970  			query:                "select id from (select foo as id from user) as t",
   971  			recursiveExpectation: T1,
   972  			expectation:          T2,
   973  		}, {
   974  			query:                "select id from (select foo as id from (select x as foo from user) as c) as t",
   975  			recursiveExpectation: T1,
   976  			expectation:          T3,
   977  		}, {
   978  			query:                "select t.id from (select foo as id from user) as t",
   979  			recursiveExpectation: T1,
   980  			expectation:          T2,
   981  		}, {
   982  			query:        "select t.id2 from (select foo as id from user) as t",
   983  			errorMessage: "symbol t.id2 not found",
   984  		}, {
   985  			query:                "select id from (select 42 as id) as t",
   986  			recursiveExpectation: T0,
   987  			expectation:          T2,
   988  		}, {
   989  			query:                "select t.id from (select 42 as id) as t",
   990  			recursiveExpectation: T0,
   991  			expectation:          T2,
   992  		}, {
   993  			query:        "select ks.t.id from (select 42 as id) as t",
   994  			errorMessage: "symbol ks.t.id not found",
   995  		}, {
   996  			query:        "select * from (select id, id from user) as t",
   997  			errorMessage: "Duplicate column name 'id'",
   998  		}, {
   999  			query:                "select t.baz = 1 from (select id as baz from user) as t",
  1000  			expectation:          T2,
  1001  			recursiveExpectation: T1,
  1002  		}, {
  1003  			query:                "select t.id from (select * from user, music) as t",
  1004  			expectation:          T3,
  1005  			recursiveExpectation: MergeTableSets(T1, T2),
  1006  		}, {
  1007  			query:                "select t.id from (select * from user, music) as t order by t.id",
  1008  			expectation:          T3,
  1009  			recursiveExpectation: MergeTableSets(T1, T2),
  1010  		}, {
  1011  			query:                "select t.id from (select * from user) as t join user as u on t.id = u.id",
  1012  			expectation:          T2,
  1013  			recursiveExpectation: T1,
  1014  		}, {
  1015  			query:                "select t.col1 from t3 ua join (select t1.id, t1.col1 from t1 join t2) as t",
  1016  			expectation:          T4,
  1017  			recursiveExpectation: T2,
  1018  		}, {
  1019  			query:        "select uu.test from (select id from t1) uu",
  1020  			errorMessage: "symbol uu.test not found",
  1021  		}, {
  1022  			query:        "select uu.id from (select id as col from t1) uu",
  1023  			errorMessage: "symbol uu.id not found",
  1024  		}, {
  1025  			query:        "select uu.id from (select id as col from t1) uu",
  1026  			errorMessage: "symbol uu.id not found",
  1027  		}, {
  1028  			query:                "select uu.id from (select id from t1) as uu where exists (select * from t2 as uu where uu.id = uu.uid)",
  1029  			expectation:          T2,
  1030  			recursiveExpectation: T1,
  1031  		}, {
  1032  			query:                "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))",
  1033  			expectation:          T0,
  1034  			recursiveExpectation: T0,
  1035  		}}
  1036  	for _, query := range queries {
  1037  		t.Run(query.query, func(t *testing.T) {
  1038  			parse, err := sqlparser.Parse(query.query)
  1039  			require.NoError(t, err)
  1040  			st, err := Analyze(parse.(sqlparser.SelectStatement), "user", &FakeSI{
  1041  				Tables: map[string]*vindexes.Table{
  1042  					"t": {Name: sqlparser.NewIdentifierCS("t")},
  1043  				},
  1044  			})
  1045  
  1046  			switch {
  1047  			case query.errorMessage != "" && err != nil:
  1048  				require.EqualError(t, err, query.errorMessage)
  1049  			case query.errorMessage != "":
  1050  				require.EqualError(t, st.NotUnshardedErr, query.errorMessage)
  1051  			default:
  1052  				require.NoError(t, err)
  1053  				sel := parse.(*sqlparser.Select)
  1054  				assert.Equal(t, query.recursiveExpectation, st.RecursiveDeps(extract(sel, 0)), "RecursiveDeps")
  1055  				assert.Equal(t, query.expectation, st.DirectDeps(extract(sel, 0)), "DirectDeps")
  1056  			}
  1057  		})
  1058  	}
  1059  }
  1060  
  1061  func TestDerivedTablesOrderClause(t *testing.T) {
  1062  	queries := []struct {
  1063  		query                string
  1064  		recursiveExpectation TableSet
  1065  		expectation          TableSet
  1066  	}{{
  1067  		query:                "select 1 from (select id from user) as t order by id",
  1068  		recursiveExpectation: T1,
  1069  		expectation:          T2,
  1070  	}, {
  1071  		query:                "select id from (select id from user) as t order by id",
  1072  		recursiveExpectation: T1,
  1073  		expectation:          T2,
  1074  	}, {
  1075  		query:                "select id from (select id from user) as t order by t.id",
  1076  		recursiveExpectation: T1,
  1077  		expectation:          T2,
  1078  	}, {
  1079  		query:                "select id as foo from (select id from user) as t order by foo",
  1080  		recursiveExpectation: T1,
  1081  		expectation:          T2,
  1082  	}, {
  1083  		query:                "select bar from (select id as bar from user) as t order by bar",
  1084  		recursiveExpectation: T1,
  1085  		expectation:          T2,
  1086  	}, {
  1087  		query:                "select bar as foo from (select id as bar from user) as t order by bar",
  1088  		recursiveExpectation: T1,
  1089  		expectation:          T2,
  1090  	}, {
  1091  		query:                "select bar as foo from (select id as bar from user) as t order by foo",
  1092  		recursiveExpectation: T1,
  1093  		expectation:          T2,
  1094  	}, {
  1095  		query:                "select bar as foo from (select id as bar, oo from user) as t order by oo",
  1096  		recursiveExpectation: T1,
  1097  		expectation:          T2,
  1098  	}, {
  1099  		query:                "select bar as foo from (select id, oo from user) as t(bar,oo) order by bar",
  1100  		recursiveExpectation: T1,
  1101  		expectation:          T2,
  1102  	}}
  1103  	si := &FakeSI{Tables: map[string]*vindexes.Table{"t": {Name: sqlparser.NewIdentifierCS("t")}}}
  1104  	for _, query := range queries {
  1105  		t.Run(query.query, func(t *testing.T) {
  1106  			parse, err := sqlparser.Parse(query.query)
  1107  			require.NoError(t, err)
  1108  
  1109  			st, err := Analyze(parse.(sqlparser.SelectStatement), "user", si)
  1110  			require.NoError(t, err)
  1111  
  1112  			sel := parse.(*sqlparser.Select)
  1113  			assert.Equal(t, query.recursiveExpectation, st.RecursiveDeps(sel.OrderBy[0].Expr), "RecursiveDeps")
  1114  			assert.Equal(t, query.expectation, st.DirectDeps(sel.OrderBy[0].Expr), "DirectDeps")
  1115  
  1116  		})
  1117  	}
  1118  }
  1119  
  1120  func TestScopingWComplexDerivedTables(t *testing.T) {
  1121  	queries := []struct {
  1122  		query            string
  1123  		errorMessage     string
  1124  		rightExpectation TableSet
  1125  		leftExpectation  TableSet
  1126  	}{
  1127  		{
  1128  			query:            "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))",
  1129  			rightExpectation: T1,
  1130  			leftExpectation:  T1,
  1131  		},
  1132  		{
  1133  			query:            "select 1 from user.user uu where exists (select 1 from user.user as uu where exists (select 1 from (select 1 from user.t1) uu where uu.user_id = uu.id))",
  1134  			rightExpectation: T2,
  1135  			leftExpectation:  T2,
  1136  		},
  1137  	}
  1138  	for _, query := range queries {
  1139  		t.Run(query.query, func(t *testing.T) {
  1140  			parse, err := sqlparser.Parse(query.query)
  1141  			require.NoError(t, err)
  1142  			st, err := Analyze(parse.(sqlparser.SelectStatement), "user", &FakeSI{
  1143  				Tables: map[string]*vindexes.Table{
  1144  					"t": {Name: sqlparser.NewIdentifierCS("t")},
  1145  				},
  1146  			})
  1147  			if query.errorMessage != "" {
  1148  				require.EqualError(t, err, query.errorMessage)
  1149  			} else {
  1150  				require.NoError(t, err)
  1151  				sel := parse.(*sqlparser.Select)
  1152  				comparisonExpr := sel.Where.Expr.(*sqlparser.ExistsExpr).Subquery.Select.(*sqlparser.Select).Where.Expr.(*sqlparser.ExistsExpr).Subquery.Select.(*sqlparser.Select).Where.Expr.(*sqlparser.ComparisonExpr)
  1153  				left := comparisonExpr.Left
  1154  				right := comparisonExpr.Right
  1155  				assert.Equal(t, query.leftExpectation, st.RecursiveDeps(left), "Left RecursiveDeps")
  1156  				assert.Equal(t, query.rightExpectation, st.RecursiveDeps(right), "Right RecursiveDeps")
  1157  			}
  1158  		})
  1159  	}
  1160  }
  1161  
  1162  func TestScopingWVindexTables(t *testing.T) {
  1163  	queries := []struct {
  1164  		query                string
  1165  		errorMessage         string
  1166  		recursiveExpectation TableSet
  1167  		expectation          TableSet
  1168  	}{
  1169  		{
  1170  			query:                "select id from user_index where id = 1",
  1171  			recursiveExpectation: T1,
  1172  			expectation:          T1,
  1173  		}, {
  1174  			query:                "select u.id + t.id from t as t join user_index as u where u.id = 1 and u.id = t.id",
  1175  			recursiveExpectation: MergeTableSets(T1, T2),
  1176  			expectation:          MergeTableSets(T1, T2),
  1177  		},
  1178  	}
  1179  	for _, query := range queries {
  1180  		t.Run(query.query, func(t *testing.T) {
  1181  			parse, err := sqlparser.Parse(query.query)
  1182  			require.NoError(t, err)
  1183  			hash, _ := vindexes.NewHash("user_index", nil)
  1184  			st, err := Analyze(parse.(sqlparser.SelectStatement), "user", &FakeSI{
  1185  				Tables: map[string]*vindexes.Table{
  1186  					"t": {Name: sqlparser.NewIdentifierCS("t")},
  1187  				},
  1188  				VindexTables: map[string]vindexes.Vindex{
  1189  					"user_index": hash,
  1190  				},
  1191  			})
  1192  			if query.errorMessage != "" {
  1193  				require.EqualError(t, err, query.errorMessage)
  1194  			} else {
  1195  				require.NoError(t, err)
  1196  				sel := parse.(*sqlparser.Select)
  1197  				assert.Equal(t, query.recursiveExpectation, st.RecursiveDeps(extract(sel, 0)))
  1198  				assert.Equal(t, query.expectation, st.DirectDeps(extract(sel, 0)))
  1199  			}
  1200  		})
  1201  	}
  1202  }
  1203  
  1204  func BenchmarkAnalyzeMultipleDifferentQueries(b *testing.B) {
  1205  	queries := []string{
  1206  		"select col from tabl",
  1207  		"select t.col from t, s",
  1208  		"select max(tabl.col1 + tabl.col2) from d.X as tabl",
  1209  		"select max(X.col + S.col) from t as X, s as S",
  1210  		"select case t.col when s.col then r.col else u.col end from t, s, r, w, u",
  1211  		"select t.col1, (select t.col2 from z as t) from x as t",
  1212  		"select * from user u where exists (select * from user order by col)",
  1213  		"select id from dbName.t1 where exists (select * from dbName.t2 order by dbName.t1.id)",
  1214  		"select d.tabl.col from d.tabl order by col",
  1215  		"select a.id from t1 as a union (select uid from t2, t union (select name from t) order by 1) order by 1",
  1216  		"select a.id from t, t1 as a group by id",
  1217  		"select tabl.col as x from tabl having x = 1",
  1218  		"select id from (select foo as id from (select x as foo from user) as c) as t",
  1219  	}
  1220  
  1221  	for i := 0; i < b.N; i++ {
  1222  		for _, query := range queries {
  1223  			parse, err := sqlparser.Parse(query)
  1224  			require.NoError(b, err)
  1225  
  1226  			_, _ = Analyze(parse.(sqlparser.SelectStatement), "d", fakeSchemaInfo())
  1227  		}
  1228  	}
  1229  }
  1230  
  1231  func BenchmarkAnalyzeUnionQueries(b *testing.B) {
  1232  	queries := []string{
  1233  		"select id from t1 union select uid from t2",
  1234  		"select col1 from tabl1 union (select col2 from tabl2)",
  1235  		"select t1.id, t1.col1 from t1 union select t2.uid from t2",
  1236  		"select a.id from t1 as a union (select uid from t2, t union (select name from t) order by 1) order by 1",
  1237  		"select b.id as a from t1 as b union (select uid as c from t2) order by 1",
  1238  		"select a.id from t1 as a union (select uid from t2) order by 1",
  1239  		"select id from t1 union select uid from t2 union (select name from t)",
  1240  		"select id from t1 union (select uid from t2) order by 1",
  1241  		"(select id from t1) union (select uid from t2) order by id",
  1242  		"select a.id from t1 as a union (select uid from t2, t union (select name from t) order by 1) order by 1",
  1243  	}
  1244  
  1245  	for i := 0; i < b.N; i++ {
  1246  		for _, query := range queries {
  1247  			parse, err := sqlparser.Parse(query)
  1248  			require.NoError(b, err)
  1249  
  1250  			_, _ = Analyze(parse.(sqlparser.SelectStatement), "d", fakeSchemaInfo())
  1251  		}
  1252  	}
  1253  }
  1254  
  1255  func BenchmarkAnalyzeSubQueries(b *testing.B) {
  1256  	queries := []string{
  1257  		"select * from user u where exists (select * from user order by col)",
  1258  		"select * from user u where exists (select * from user order by user.col)",
  1259  		"select * from user u where exists (select * from user order by u.col)",
  1260  		"select * from dbName.user as u where exists (select * from dbName.user order by u.col)",
  1261  		"select * from dbName.user where exists (select * from otherDb.user order by dbName.user.col)",
  1262  		"select id from dbName.t1 where exists (select * from dbName.t2 order by dbName.t1.id)",
  1263  		"select t.col1, (select t.col2 from z as t) from x as t",
  1264  		"select t.col1, (select t.col2 from z) from x as t",
  1265  		"select t.col1, (select (select z.col2 from y) from z) from x as t",
  1266  		"select t.col1, (select (select y.col2 from y) from z) from x as t",
  1267  		"select t.col1, (select (select (select (select w.col2 from w) from x) from y) from z) from x as t",
  1268  		"select t.col1, (select id from t) from x as t",
  1269  	}
  1270  
  1271  	for i := 0; i < b.N; i++ {
  1272  		for _, query := range queries {
  1273  			parse, err := sqlparser.Parse(query)
  1274  			require.NoError(b, err)
  1275  
  1276  			_, _ = Analyze(parse.(sqlparser.SelectStatement), "d", fakeSchemaInfo())
  1277  		}
  1278  	}
  1279  }
  1280  
  1281  func BenchmarkAnalyzeDerivedTableQueries(b *testing.B) {
  1282  	queries := []string{
  1283  		"select id from (select x as id from user) as t",
  1284  		"select id from (select foo as id from user) as t",
  1285  		"select id from (select foo as id from (select x as foo from user) as c) as t",
  1286  		"select t.id from (select foo as id from user) as t",
  1287  		"select t.id2 from (select foo as id from user) as t",
  1288  		"select id from (select 42 as id) as t",
  1289  		"select t.id from (select 42 as id) as t",
  1290  		"select ks.t.id from (select 42 as id) as t",
  1291  		"select * from (select id, id from user) as t",
  1292  		"select t.baz = 1 from (select id as baz from user) as t",
  1293  		"select t.id from (select * from user, music) as t",
  1294  		"select t.id from (select * from user, music) as t order by t.id",
  1295  		"select t.id from (select * from user) as t join user as u on t.id = u.id",
  1296  		"select t.col1 from t3 ua join (select t1.id, t1.col1 from t1 join t2) as t",
  1297  		"select uu.id from (select id from t1) as uu where exists (select * from t2 as uu where uu.id = uu.uid)",
  1298  		"select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))",
  1299  	}
  1300  
  1301  	for i := 0; i < b.N; i++ {
  1302  		for _, query := range queries {
  1303  			parse, err := sqlparser.Parse(query)
  1304  			require.NoError(b, err)
  1305  
  1306  			_, _ = Analyze(parse.(sqlparser.SelectStatement), "d", fakeSchemaInfo())
  1307  		}
  1308  	}
  1309  }
  1310  
  1311  func BenchmarkAnalyzeHavingQueries(b *testing.B) {
  1312  	queries := []string{
  1313  		"select col from tabl having col = 1",
  1314  		"select col from tabl having tabl.col = 1",
  1315  		"select col from tabl having d.tabl.col = 1",
  1316  		"select tabl.col as x from tabl having x = 1",
  1317  		"select tabl.col as x from tabl having col",
  1318  		"select col from tabl having 1 = 1",
  1319  		"select col as c from tabl having c = 1",
  1320  		"select 1 as c from tabl having c = 1",
  1321  		"select t1.id from t1, t2 having id = 1",
  1322  		"select t.id from t, t1 having id = 1",
  1323  		"select t.id, count(*) as a from t, t1 group by t.id having a = 1",
  1324  		"select u2.a, u1.a from u1, u2 having u2.a = 2",
  1325  	}
  1326  
  1327  	for i := 0; i < b.N; i++ {
  1328  		for _, query := range queries {
  1329  			parse, err := sqlparser.Parse(query)
  1330  			require.NoError(b, err)
  1331  
  1332  			_, _ = Analyze(parse.(sqlparser.SelectStatement), "d", fakeSchemaInfo())
  1333  		}
  1334  	}
  1335  }
  1336  
  1337  func BenchmarkAnalyzeGroupByQueries(b *testing.B) {
  1338  	queries := []string{
  1339  		"select col from tabl group by col",
  1340  		"select col from tabl group by tabl.col",
  1341  		"select col from tabl group by d.tabl.col",
  1342  		"select tabl.col as x from tabl group by x",
  1343  		"select tabl.col as x from tabl group by col",
  1344  		"select d.tabl.col as x from tabl group by x",
  1345  		"select d.tabl.col as x from tabl group by col",
  1346  		"select col from tabl group by 1",
  1347  		"select col as c from tabl group by c",
  1348  		"select 1 as c from tabl group by c",
  1349  		"select t1.id from t1, t2 group by id",
  1350  		"select id from t, t1 group by id",
  1351  		"select id from t, t1 group by id",
  1352  		"select a.id from t as a, t1 group by id",
  1353  		"select a.id from t, t1 as a group by id",
  1354  	}
  1355  
  1356  	for i := 0; i < b.N; i++ {
  1357  		for _, query := range queries {
  1358  			parse, err := sqlparser.Parse(query)
  1359  			require.NoError(b, err)
  1360  
  1361  			_, _ = Analyze(parse.(sqlparser.SelectStatement), "d", fakeSchemaInfo())
  1362  		}
  1363  	}
  1364  }
  1365  
  1366  func BenchmarkAnalyzeOrderByQueries(b *testing.B) {
  1367  	queries := []string{
  1368  		"select col from tabl order by col",
  1369  		"select tabl.col from d.tabl order by col",
  1370  		"select d.tabl.col from d.tabl order by col",
  1371  		"select col from tabl order by tabl.col",
  1372  		"select col from tabl order by d.tabl.col",
  1373  		"select col from tabl order by 1",
  1374  		"select col as c from tabl order by c",
  1375  		"select 1 as c from tabl order by c",
  1376  		"select name, name from t1, t2 order by name",
  1377  	}
  1378  
  1379  	for i := 0; i < b.N; i++ {
  1380  		for _, query := range queries {
  1381  			parse, err := sqlparser.Parse(query)
  1382  			require.NoError(b, err)
  1383  
  1384  			_, _ = Analyze(parse.(sqlparser.SelectStatement), "d", fakeSchemaInfo())
  1385  		}
  1386  	}
  1387  }
  1388  
  1389  func parseAndAnalyze(t *testing.T, query, dbName string) (sqlparser.Statement, *SemTable) {
  1390  	t.Helper()
  1391  	parse, err := sqlparser.Parse(query)
  1392  	require.NoError(t, err)
  1393  
  1394  	semTable, err := Analyze(parse, dbName, fakeSchemaInfo())
  1395  	require.NoError(t, err)
  1396  	return parse, semTable
  1397  }
  1398  
  1399  func TestSingleUnshardedKeyspace(t *testing.T) {
  1400  	tests := []struct {
  1401  		query     string
  1402  		unsharded *vindexes.Keyspace
  1403  		tables    []*vindexes.Table
  1404  	}{
  1405  		{
  1406  			query:     "select 1 from t, t1",
  1407  			unsharded: nil, // both tables are unsharded, but from different keyspaces
  1408  			tables:    nil,
  1409  		}, {
  1410  			query:     "select 1 from t2",
  1411  			unsharded: nil,
  1412  			tables:    nil,
  1413  		}, {
  1414  			query:     "select 1 from t, t2",
  1415  			unsharded: nil,
  1416  			tables:    nil,
  1417  		}, {
  1418  			query:     "select 1 from t as A, t as B",
  1419  			unsharded: ks1,
  1420  			tables: []*vindexes.Table{
  1421  				{Keyspace: ks1, Name: sqlparser.NewIdentifierCS("t")},
  1422  				{Keyspace: ks1, Name: sqlparser.NewIdentifierCS("t")},
  1423  			},
  1424  		},
  1425  	}
  1426  
  1427  	for _, test := range tests {
  1428  		t.Run(test.query, func(t *testing.T) {
  1429  			_, semTable := parseAndAnalyze(t, test.query, "d")
  1430  			queryIsUnsharded, tables := semTable.SingleUnshardedKeyspace()
  1431  			assert.Equal(t, test.unsharded, queryIsUnsharded)
  1432  			assert.Equal(t, test.tables, tables)
  1433  		})
  1434  	}
  1435  }
  1436  
  1437  // TestScopingSubQueryJoinClause tests the scoping behavior of a subquery containing a join clause.
  1438  // The test ensures that the scoping analysis correctly identifies and handles the relationships
  1439  // between the tables involved in the join operation with the outer query.
  1440  func TestScopingSubQueryJoinClause(t *testing.T) {
  1441  	query := "select (select 1 from u1 join u2 on u1.id = u2.id and u2.id = u3.id) x from u3"
  1442  
  1443  	parse, err := sqlparser.Parse(query)
  1444  	require.NoError(t, err)
  1445  
  1446  	st, err := Analyze(parse, "user", &FakeSI{
  1447  		Tables: map[string]*vindexes.Table{
  1448  			"t": {Name: sqlparser.NewIdentifierCS("t")},
  1449  		},
  1450  	})
  1451  	require.NoError(t, err)
  1452  	require.NoError(t, st.NotUnshardedErr)
  1453  
  1454  	tb := st.DirectDeps(parse.(*sqlparser.Select).SelectExprs[0].(*sqlparser.AliasedExpr).Expr.(*sqlparser.Subquery).Select.(*sqlparser.Select).From[0].(*sqlparser.JoinTableExpr).Condition.On)
  1455  	require.Equal(t, 3, tb.NumberOfTables())
  1456  
  1457  }
  1458  
  1459  var ks1 = &vindexes.Keyspace{
  1460  	Name:    "ks1",
  1461  	Sharded: false,
  1462  }
  1463  var ks2 = &vindexes.Keyspace{
  1464  	Name:    "ks2",
  1465  	Sharded: false,
  1466  }
  1467  var ks3 = &vindexes.Keyspace{
  1468  	Name:    "ks3",
  1469  	Sharded: true,
  1470  }
  1471  
  1472  func fakeSchemaInfo() *FakeSI {
  1473  	cols1 := []vindexes.Column{{
  1474  		Name: sqlparser.NewIdentifierCI("id"),
  1475  		Type: querypb.Type_INT64,
  1476  	}}
  1477  	cols2 := []vindexes.Column{{
  1478  		Name: sqlparser.NewIdentifierCI("uid"),
  1479  		Type: querypb.Type_INT64,
  1480  	}, {
  1481  		Name: sqlparser.NewIdentifierCI("name"),
  1482  		Type: querypb.Type_VARCHAR,
  1483  	}}
  1484  
  1485  	si := &FakeSI{
  1486  		Tables: map[string]*vindexes.Table{
  1487  			"t":  {Name: sqlparser.NewIdentifierCS("t"), Keyspace: ks1},
  1488  			"t1": {Name: sqlparser.NewIdentifierCS("t1"), Columns: cols1, ColumnListAuthoritative: true, Keyspace: ks2},
  1489  			"t2": {Name: sqlparser.NewIdentifierCS("t2"), Columns: cols2, ColumnListAuthoritative: true, Keyspace: ks3},
  1490  		},
  1491  	}
  1492  	return si
  1493  }