vitess.io/vitess@v0.16.2/go/vt/vtgate/semantics/analyzer_update_test.go (about) 1 /* 2 Copyright 2022 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package semantics 18 19 import ( 20 "testing" 21 22 "github.com/stretchr/testify/assert" 23 "github.com/stretchr/testify/require" 24 25 "vitess.io/vitess/go/vt/sqlparser" 26 ) 27 28 func TestUpdBindingColName(t *testing.T) { 29 queries := []string{ 30 "update tabl set col = 1", 31 "update t2 set uid = 5", 32 "update tabl set tabl.col = 1 ", 33 "update tabl set d.tabl.col = 5", 34 "update d.tabl set col = 1", 35 "update d.tabl set tabl.col = 5", 36 "update d.tabl set d.tabl.col = 1", 37 } 38 for _, query := range queries { 39 t.Run(query, func(t *testing.T) { 40 stmt, semTable := parseAndAnalyze(t, query, "d") 41 upd, _ := stmt.(*sqlparser.Update) 42 t1 := upd.TableExprs[0].(*sqlparser.AliasedTableExpr) 43 ts := semTable.TableSetFor(t1) 44 assert.Equal(t, SingleTableSet(0), ts) 45 46 updExpr := extractFromUpdateSet(upd, 0) 47 recursiveDeps := semTable.RecursiveDeps(updExpr.Name) 48 assert.Equal(t, T1, recursiveDeps, query) 49 assert.Equal(t, T1, semTable.DirectDeps(updExpr.Name), query) 50 assert.Equal(t, 1, recursiveDeps.NumberOfTables(), "number of tables is wrong") 51 52 recursiveDeps = semTable.RecursiveDeps(updExpr.Expr) 53 assert.Equal(t, None, recursiveDeps, query) 54 assert.Equal(t, None, semTable.DirectDeps(updExpr.Expr), query) 55 }) 56 } 57 } 58 59 func TestUpdBindingExpr(t *testing.T) { 60 queries := []string{ 61 "update tabl set col = col", 62 "update tabl set d.tabl.col = tabl.col", 63 "update d.tabl set col = d.tabl.col", 64 "update d.tabl set tabl.col = col", 65 } 66 for _, query := range queries { 67 t.Run(query, func(t *testing.T) { 68 stmt, semTable := parseAndAnalyze(t, query, "d") 69 upd, _ := stmt.(*sqlparser.Update) 70 t1 := upd.TableExprs[0].(*sqlparser.AliasedTableExpr) 71 ts := semTable.TableSetFor(t1) 72 assert.Equal(t, SingleTableSet(0), ts) 73 74 updExpr := extractFromUpdateSet(upd, 0) 75 recursiveDeps := semTable.RecursiveDeps(updExpr.Name) 76 assert.Equal(t, T1, recursiveDeps, query) 77 assert.Equal(t, T1, semTable.DirectDeps(updExpr.Name), query) 78 assert.Equal(t, 1, recursiveDeps.NumberOfTables(), "number of tables is wrong") 79 80 recursiveDeps = semTable.RecursiveDeps(updExpr.Expr) 81 assert.Equal(t, T1, recursiveDeps, query) 82 assert.Equal(t, T1, semTable.DirectDeps(updExpr.Expr), query) 83 assert.Equal(t, 1, recursiveDeps.NumberOfTables(), "number of tables is wrong") 84 }) 85 } 86 } 87 88 func TestUpdSetSubquery(t *testing.T) { 89 queries := []string{ 90 "update tabl set col = (select id from a)", 91 "update tabl set col = (select id from a)+1", 92 "update tabl set col = 1 IN (select id from a)", 93 "update tabl set col = (select id from a), t = (select x from a)", 94 } 95 for _, query := range queries { 96 t.Run(query, func(t *testing.T) { 97 stmt, semTable := parseAndAnalyze(t, query, "d") 98 upd, _ := stmt.(*sqlparser.Update) 99 t1 := upd.TableExprs[0].(*sqlparser.AliasedTableExpr) 100 ts := semTable.TableSetFor(t1) 101 assert.Equal(t, SingleTableSet(0), ts) 102 103 updExpr := extractFromUpdateSet(upd, 0) 104 recursiveDeps := semTable.RecursiveDeps(updExpr.Name) 105 assert.Equal(t, T1, recursiveDeps, query) 106 assert.Equal(t, T1, semTable.DirectDeps(updExpr.Name), query) 107 assert.Equal(t, 1, recursiveDeps.NumberOfTables(), "number of tables is wrong") 108 109 extractedSubqs := semTable.SubqueryMap[upd] 110 require.Len(t, extractedSubqs, len(upd.Exprs)) 111 112 for _, esubq := range extractedSubqs { 113 subq := esubq.Subquery 114 extractedSubq := semTable.SubqueryRef[subq] 115 assert.True(t, sqlparser.Equals.Expr(extractedSubq.Subquery, subq)) 116 } 117 }) 118 } 119 } 120 121 func TestUpdWhereSubquery(t *testing.T) { 122 queries := []string{ 123 "update tabl set col = 1 where id = (select id from a)", 124 "update tabl set col = 1 where id IN (select id from a)", 125 "update tabl set col = 1 where exists (select id from a)", 126 "update tabl set col = 1 where 1 = (select id from a)", 127 "update tabl set col = 1 where exists (select id from a) and id > (select name from city) and col < (select i from a)", 128 } 129 for _, query := range queries { 130 t.Run(query, func(t *testing.T) { 131 stmt, semTable := parseAndAnalyze(t, query, "d") 132 upd, _ := stmt.(*sqlparser.Update) 133 t1 := upd.TableExprs[0].(*sqlparser.AliasedTableExpr) 134 ts := semTable.TableSetFor(t1) 135 assert.Equal(t, SingleTableSet(0), ts) 136 137 extractedSubqs := semTable.SubqueryMap[upd] 138 require.Len(t, extractedSubqs, len(sqlparser.SplitAndExpression(nil, upd.Where.Expr))) 139 140 for _, esubq := range extractedSubqs { 141 subq := esubq.Subquery 142 extractedSubq := semTable.SubqueryRef[subq] 143 assert.True(t, sqlparser.Equals.Expr(extractedSubq.Subquery, subq)) 144 } 145 }) 146 } 147 } 148 149 func TestUpdSetAndWhereSubquery(t *testing.T) { 150 queries := []string{ 151 "update tabl set col = (select b from alpha) where id = (select id from a)", 152 "update tabl set col = (select b from alpha) where exists (select id from a)", 153 "update tabl set col = 1+(select b from alpha) where 1 > (select id from a)", 154 } 155 for _, query := range queries { 156 t.Run(query, func(t *testing.T) { 157 stmt, semTable := parseAndAnalyze(t, query, "d") 158 upd, _ := stmt.(*sqlparser.Update) 159 t1 := upd.TableExprs[0].(*sqlparser.AliasedTableExpr) 160 ts := semTable.TableSetFor(t1) 161 assert.Equal(t, SingleTableSet(0), ts) 162 163 extractedSubqs := semTable.SubqueryMap[upd] 164 require.Len(t, extractedSubqs, len(sqlparser.SplitAndExpression(nil, upd.Where.Expr))+len(upd.Exprs)) 165 166 for _, esubq := range extractedSubqs { 167 subq := esubq.Subquery 168 extractedSubq := semTable.SubqueryRef[subq] 169 assert.True(t, sqlparser.Equals.Expr(extractedSubq.Subquery, subq)) 170 } 171 }) 172 } 173 } 174 175 func extractFromUpdateSet(in *sqlparser.Update, idx int) *sqlparser.UpdateExpr { 176 return in.Exprs[idx] 177 }