vitess.io/vitess@v0.16.2/go/vt/vtgate/planbuilder/rewrite_test.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package planbuilder 18 19 import ( 20 "testing" 21 22 "github.com/stretchr/testify/assert" 23 "github.com/stretchr/testify/require" 24 25 "vitess.io/vitess/go/vt/sqlparser" 26 "vitess.io/vitess/go/vt/vtgate/semantics" 27 ) 28 29 func TestSubqueryRewrite(t *testing.T) { 30 tcases := []struct { 31 input string 32 output string 33 }{{ 34 input: "select 1 from t1", 35 output: "select 1 from t1", 36 }, { 37 input: "select (select 1) from t1", 38 output: "select :__sq1 from t1", 39 }, { 40 input: "select 1 from t1 where exists (select 1)", 41 output: "select 1 from t1 where :__sq_has_values1", 42 }, { 43 input: "select id from t1 where id in (select 1)", 44 output: "select id from t1 where :__sq_has_values1 = 1 and id in ::__sq1", 45 }, { 46 input: "select id from t1 where id not in (select 1)", 47 output: "select id from t1 where :__sq_has_values1 = 0 or id not in ::__sq1", 48 }, { 49 input: "select id from t1 where id = (select 1)", 50 output: "select id from t1 where id = :__sq1", 51 }, { 52 input: "select id from t1 where id >= (select 1)", 53 output: "select id from t1 where id >= :__sq1", 54 }, { 55 input: "select id from t1 where t1.id = (select 1 from t2 where t2.id = t1.id)", 56 output: "select id from t1 where t1.id = :__sq1", 57 }, { 58 input: "select id from t1 join t2 where t1.id = t2.id and exists (select 1)", 59 output: "select id from t1 join t2 where t1.id = t2.id and :__sq_has_values1", 60 }, { 61 input: "select id from t1 where not exists (select 1)", 62 output: "select id from t1 where not :__sq_has_values1", 63 }, { 64 input: "select id from t1 where not exists (select 1) and exists (select 2)", 65 output: "select id from t1 where not :__sq_has_values1 and :__sq_has_values2", 66 }, { 67 input: "select (select 1), (select 2) from t1 join t2 on t1.id = (select 1) where t1.id in (select 1)", 68 output: "select :__sq2, :__sq3 from t1 join t2 on t1.id = :__sq1 where :__sq_has_values4 = 1 and t1.id in ::__sq4", 69 }} 70 for _, tcase := range tcases { 71 t.Run(tcase.input, func(t *testing.T) { 72 ast, vars, err := sqlparser.Parse2(tcase.input) 73 require.NoError(t, err) 74 reservedVars := sqlparser.NewReservedVars("vtg", vars) 75 selectStatement, isSelectStatement := ast.(*sqlparser.Select) 76 require.True(t, isSelectStatement, "analyzer expects a select statement") 77 semTable, err := semantics.Analyze(selectStatement, "", &semantics.FakeSI{}) 78 require.NoError(t, err) 79 err = queryRewrite(semTable, reservedVars, selectStatement) 80 require.NoError(t, err) 81 assert.Equal(t, tcase.output, sqlparser.String(selectStatement)) 82 }) 83 } 84 } 85 86 func TestHavingRewrite(t *testing.T) { 87 tcases := []struct { 88 input string 89 output string 90 sqs map[string]string 91 }{{ 92 input: "select 1 from t1 having a = 1", 93 output: "select 1 from t1 where a = 1", 94 }, { 95 input: "select 1 from t1 where x = 1 and y = 2 having a = 1", 96 output: "select 1 from t1 where x = 1 and y = 2 and a = 1", 97 }, { 98 input: "select 1 from t1 where x = 1 or y = 2 having a = 1", 99 output: "select 1 from t1 where (x = 1 or y = 2) and a = 1", 100 }, { 101 input: "select 1 from t1 where x = 1 having a = 1 and b = 2", 102 output: "select 1 from t1 where x = 1 and a = 1 and b = 2", 103 }, { 104 input: "select 1 from t1 where x = 1 having a = 1 or b = 2", 105 output: "select 1 from t1 where x = 1 and (a = 1 or b = 2)", 106 }, { 107 input: "select 1 from t1 where x = 1 and y = 2 having a = 1 and b = 2", 108 output: "select 1 from t1 where x = 1 and y = 2 and a = 1 and b = 2", 109 }, { 110 input: "select 1 from t1 where x = 1 or y = 2 having a = 1 and b = 2", 111 output: "select 1 from t1 where (x = 1 or y = 2) and a = 1 and b = 2", 112 }, { 113 input: "select 1 from t1 where x = 1 and y = 2 having a = 1 or b = 2", 114 output: "select 1 from t1 where x = 1 and y = 2 and (a = 1 or b = 2)", 115 }, { 116 input: "select 1 from t1 where x = 1 or y = 2 having a = 1 or b = 2", 117 output: "select 1 from t1 where (x = 1 or y = 2) and (a = 1 or b = 2)", 118 }, { 119 input: "select 1 from t1 where x = 1 or y = 2 having a = 1 and count(*) = 1", 120 output: "select 1 from t1 where (x = 1 or y = 2) and a = 1 having count(*) = 1", 121 }, { 122 input: "select count(*) k from t1 where x = 1 or y = 2 having a = 1 and k = 1", 123 output: "select count(*) as k from t1 where (x = 1 or y = 2) and a = 1 having count(*) = 1", 124 }, { 125 input: "select count(*) k from t1 having k = 10", 126 output: "select count(*) as k from t1 having count(*) = 10", 127 }, { 128 input: "select 1 from t1 where x in (select 1 from t2 having a = 1)", 129 output: "select 1 from t1 where :__sq_has_values1 = 1 and x in ::__sq1", 130 sqs: map[string]string{"__sq1": "select 1 from t2 where a = 1"}, 131 }, {input: "select 1 from t1 group by a having a = 1 and count(*) > 1", 132 output: "select 1 from t1 where a = 1 group by a having count(*) > 1", 133 }} 134 for _, tcase := range tcases { 135 t.Run(tcase.input, func(t *testing.T) { 136 semTable, reservedVars, sel := prepTest(t, tcase.input) 137 err := queryRewrite(semTable, reservedVars, sel) 138 require.NoError(t, err) 139 assert.Equal(t, tcase.output, sqlparser.String(sel)) 140 squeries, found := semTable.SubqueryMap[sel] 141 if len(tcase.sqs) > 0 { 142 assert.True(t, found, "no subquery found in the query") 143 assert.Equal(t, len(tcase.sqs), len(squeries), "number of subqueries not matched") 144 } 145 for _, sq := range squeries { 146 assert.Equal(t, tcase.sqs[sq.GetArgName()], sqlparser.String(sq.Subquery.Select)) 147 } 148 }) 149 } 150 } 151 152 func prepTest(t *testing.T, sql string) (*semantics.SemTable, *sqlparser.ReservedVars, *sqlparser.Select) { 153 ast, vars, err := sqlparser.Parse2(sql) 154 require.NoError(t, err) 155 156 sel, isSelectStatement := ast.(*sqlparser.Select) 157 require.True(t, isSelectStatement, "analyzer expects a select statement") 158 159 reservedVars := sqlparser.NewReservedVars("vtg", vars) 160 semTable, err := semantics.Analyze(sel, "", &semantics.FakeSI{}) 161 require.NoError(t, err) 162 163 return semTable, reservedVars, sel 164 }