vitess.io/vitess@v0.16.2/go/vt/vtgate/planbuilder/rewrite_test.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package planbuilder
    18  
    19  import (
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  
    25  	"vitess.io/vitess/go/vt/sqlparser"
    26  	"vitess.io/vitess/go/vt/vtgate/semantics"
    27  )
    28  
    29  func TestSubqueryRewrite(t *testing.T) {
    30  	tcases := []struct {
    31  		input  string
    32  		output string
    33  	}{{
    34  		input:  "select 1 from t1",
    35  		output: "select 1 from t1",
    36  	}, {
    37  		input:  "select (select 1) from t1",
    38  		output: "select :__sq1 from t1",
    39  	}, {
    40  		input:  "select 1 from t1 where exists (select 1)",
    41  		output: "select 1 from t1 where :__sq_has_values1",
    42  	}, {
    43  		input:  "select id from t1 where id in (select 1)",
    44  		output: "select id from t1 where :__sq_has_values1 = 1 and id in ::__sq1",
    45  	}, {
    46  		input:  "select id from t1 where id not in (select 1)",
    47  		output: "select id from t1 where :__sq_has_values1 = 0 or id not in ::__sq1",
    48  	}, {
    49  		input:  "select id from t1 where id = (select 1)",
    50  		output: "select id from t1 where id = :__sq1",
    51  	}, {
    52  		input:  "select id from t1 where id >= (select 1)",
    53  		output: "select id from t1 where id >= :__sq1",
    54  	}, {
    55  		input:  "select id from t1 where t1.id = (select 1 from t2 where t2.id = t1.id)",
    56  		output: "select id from t1 where t1.id = :__sq1",
    57  	}, {
    58  		input:  "select id from t1 join t2 where t1.id = t2.id and exists (select 1)",
    59  		output: "select id from t1 join t2 where t1.id = t2.id and :__sq_has_values1",
    60  	}, {
    61  		input:  "select id from t1 where not exists (select 1)",
    62  		output: "select id from t1 where not :__sq_has_values1",
    63  	}, {
    64  		input:  "select id from t1 where not exists (select 1) and exists (select 2)",
    65  		output: "select id from t1 where not :__sq_has_values1 and :__sq_has_values2",
    66  	}, {
    67  		input:  "select (select 1), (select 2) from t1 join t2 on t1.id = (select 1) where t1.id in (select 1)",
    68  		output: "select :__sq2, :__sq3 from t1 join t2 on t1.id = :__sq1 where :__sq_has_values4 = 1 and t1.id in ::__sq4",
    69  	}}
    70  	for _, tcase := range tcases {
    71  		t.Run(tcase.input, func(t *testing.T) {
    72  			ast, vars, err := sqlparser.Parse2(tcase.input)
    73  			require.NoError(t, err)
    74  			reservedVars := sqlparser.NewReservedVars("vtg", vars)
    75  			selectStatement, isSelectStatement := ast.(*sqlparser.Select)
    76  			require.True(t, isSelectStatement, "analyzer expects a select statement")
    77  			semTable, err := semantics.Analyze(selectStatement, "", &semantics.FakeSI{})
    78  			require.NoError(t, err)
    79  			err = queryRewrite(semTable, reservedVars, selectStatement)
    80  			require.NoError(t, err)
    81  			assert.Equal(t, tcase.output, sqlparser.String(selectStatement))
    82  		})
    83  	}
    84  }
    85  
    86  func TestHavingRewrite(t *testing.T) {
    87  	tcases := []struct {
    88  		input  string
    89  		output string
    90  		sqs    map[string]string
    91  	}{{
    92  		input:  "select 1 from t1 having a = 1",
    93  		output: "select 1 from t1 where a = 1",
    94  	}, {
    95  		input:  "select 1 from t1 where x = 1 and y = 2 having a = 1",
    96  		output: "select 1 from t1 where x = 1 and y = 2 and a = 1",
    97  	}, {
    98  		input:  "select 1 from t1 where x = 1 or y = 2 having a = 1",
    99  		output: "select 1 from t1 where (x = 1 or y = 2) and a = 1",
   100  	}, {
   101  		input:  "select 1 from t1 where x = 1 having a = 1 and b = 2",
   102  		output: "select 1 from t1 where x = 1 and a = 1 and b = 2",
   103  	}, {
   104  		input:  "select 1 from t1 where x = 1 having a = 1 or b = 2",
   105  		output: "select 1 from t1 where x = 1 and (a = 1 or b = 2)",
   106  	}, {
   107  		input:  "select 1 from t1 where x = 1 and y = 2 having a = 1 and b = 2",
   108  		output: "select 1 from t1 where x = 1 and y = 2 and a = 1 and b = 2",
   109  	}, {
   110  		input:  "select 1 from t1 where x = 1 or y = 2 having a = 1 and b = 2",
   111  		output: "select 1 from t1 where (x = 1 or y = 2) and a = 1 and b = 2",
   112  	}, {
   113  		input:  "select 1 from t1 where x = 1 and y = 2 having a = 1 or b = 2",
   114  		output: "select 1 from t1 where x = 1 and y = 2 and (a = 1 or b = 2)",
   115  	}, {
   116  		input:  "select 1 from t1 where x = 1 or y = 2 having a = 1 or b = 2",
   117  		output: "select 1 from t1 where (x = 1 or y = 2) and (a = 1 or b = 2)",
   118  	}, {
   119  		input:  "select 1 from t1 where x = 1 or y = 2 having a = 1 and count(*) = 1",
   120  		output: "select 1 from t1 where (x = 1 or y = 2) and a = 1 having count(*) = 1",
   121  	}, {
   122  		input:  "select count(*) k from t1 where x = 1 or y = 2 having a = 1 and k = 1",
   123  		output: "select count(*) as k from t1 where (x = 1 or y = 2) and a = 1 having count(*) = 1",
   124  	}, {
   125  		input:  "select count(*) k from t1 having k = 10",
   126  		output: "select count(*) as k from t1 having count(*) = 10",
   127  	}, {
   128  		input:  "select 1 from t1 where x in (select 1 from t2 having a = 1)",
   129  		output: "select 1 from t1 where :__sq_has_values1 = 1 and x in ::__sq1",
   130  		sqs:    map[string]string{"__sq1": "select 1 from t2 where a = 1"},
   131  	}, {input: "select 1 from t1 group by a having a = 1 and count(*) > 1",
   132  		output: "select 1 from t1 where a = 1 group by a having count(*) > 1",
   133  	}}
   134  	for _, tcase := range tcases {
   135  		t.Run(tcase.input, func(t *testing.T) {
   136  			semTable, reservedVars, sel := prepTest(t, tcase.input)
   137  			err := queryRewrite(semTable, reservedVars, sel)
   138  			require.NoError(t, err)
   139  			assert.Equal(t, tcase.output, sqlparser.String(sel))
   140  			squeries, found := semTable.SubqueryMap[sel]
   141  			if len(tcase.sqs) > 0 {
   142  				assert.True(t, found, "no subquery found in the query")
   143  				assert.Equal(t, len(tcase.sqs), len(squeries), "number of subqueries not matched")
   144  			}
   145  			for _, sq := range squeries {
   146  				assert.Equal(t, tcase.sqs[sq.GetArgName()], sqlparser.String(sq.Subquery.Select))
   147  			}
   148  		})
   149  	}
   150  }
   151  
   152  func prepTest(t *testing.T, sql string) (*semantics.SemTable, *sqlparser.ReservedVars, *sqlparser.Select) {
   153  	ast, vars, err := sqlparser.Parse2(sql)
   154  	require.NoError(t, err)
   155  
   156  	sel, isSelectStatement := ast.(*sqlparser.Select)
   157  	require.True(t, isSelectStatement, "analyzer expects a select statement")
   158  
   159  	reservedVars := sqlparser.NewReservedVars("vtg", vars)
   160  	semTable, err := semantics.Analyze(sel, "", &semantics.FakeSI{})
   161  	require.NoError(t, err)
   162  
   163  	return semTable, reservedVars, sel
   164  }