vitess.io/vitess@v0.16.2/go/vt/vtgate/semantics/analyzer_update_test.go (about)

     1  /*
     2  Copyright 2022 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package semantics
    18  
    19  import (
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  
    25  	"vitess.io/vitess/go/vt/sqlparser"
    26  )
    27  
    28  func TestUpdBindingColName(t *testing.T) {
    29  	queries := []string{
    30  		"update tabl set col = 1",
    31  		"update t2 set uid = 5",
    32  		"update tabl set tabl.col = 1 ",
    33  		"update tabl set d.tabl.col = 5",
    34  		"update d.tabl set col = 1",
    35  		"update d.tabl set tabl.col = 5",
    36  		"update d.tabl set d.tabl.col = 1",
    37  	}
    38  	for _, query := range queries {
    39  		t.Run(query, func(t *testing.T) {
    40  			stmt, semTable := parseAndAnalyze(t, query, "d")
    41  			upd, _ := stmt.(*sqlparser.Update)
    42  			t1 := upd.TableExprs[0].(*sqlparser.AliasedTableExpr)
    43  			ts := semTable.TableSetFor(t1)
    44  			assert.Equal(t, SingleTableSet(0), ts)
    45  
    46  			updExpr := extractFromUpdateSet(upd, 0)
    47  			recursiveDeps := semTable.RecursiveDeps(updExpr.Name)
    48  			assert.Equal(t, T1, recursiveDeps, query)
    49  			assert.Equal(t, T1, semTable.DirectDeps(updExpr.Name), query)
    50  			assert.Equal(t, 1, recursiveDeps.NumberOfTables(), "number of tables is wrong")
    51  
    52  			recursiveDeps = semTable.RecursiveDeps(updExpr.Expr)
    53  			assert.Equal(t, None, recursiveDeps, query)
    54  			assert.Equal(t, None, semTable.DirectDeps(updExpr.Expr), query)
    55  		})
    56  	}
    57  }
    58  
    59  func TestUpdBindingExpr(t *testing.T) {
    60  	queries := []string{
    61  		"update tabl set col = col",
    62  		"update tabl set d.tabl.col = tabl.col",
    63  		"update d.tabl set col = d.tabl.col",
    64  		"update d.tabl set tabl.col = col",
    65  	}
    66  	for _, query := range queries {
    67  		t.Run(query, func(t *testing.T) {
    68  			stmt, semTable := parseAndAnalyze(t, query, "d")
    69  			upd, _ := stmt.(*sqlparser.Update)
    70  			t1 := upd.TableExprs[0].(*sqlparser.AliasedTableExpr)
    71  			ts := semTable.TableSetFor(t1)
    72  			assert.Equal(t, SingleTableSet(0), ts)
    73  
    74  			updExpr := extractFromUpdateSet(upd, 0)
    75  			recursiveDeps := semTable.RecursiveDeps(updExpr.Name)
    76  			assert.Equal(t, T1, recursiveDeps, query)
    77  			assert.Equal(t, T1, semTable.DirectDeps(updExpr.Name), query)
    78  			assert.Equal(t, 1, recursiveDeps.NumberOfTables(), "number of tables is wrong")
    79  
    80  			recursiveDeps = semTable.RecursiveDeps(updExpr.Expr)
    81  			assert.Equal(t, T1, recursiveDeps, query)
    82  			assert.Equal(t, T1, semTable.DirectDeps(updExpr.Expr), query)
    83  			assert.Equal(t, 1, recursiveDeps.NumberOfTables(), "number of tables is wrong")
    84  		})
    85  	}
    86  }
    87  
    88  func TestUpdSetSubquery(t *testing.T) {
    89  	queries := []string{
    90  		"update tabl set col = (select id from a)",
    91  		"update tabl set col = (select id from a)+1",
    92  		"update tabl set col = 1 IN (select id from a)",
    93  		"update tabl set col = (select id from a), t = (select x from a)",
    94  	}
    95  	for _, query := range queries {
    96  		t.Run(query, func(t *testing.T) {
    97  			stmt, semTable := parseAndAnalyze(t, query, "d")
    98  			upd, _ := stmt.(*sqlparser.Update)
    99  			t1 := upd.TableExprs[0].(*sqlparser.AliasedTableExpr)
   100  			ts := semTable.TableSetFor(t1)
   101  			assert.Equal(t, SingleTableSet(0), ts)
   102  
   103  			updExpr := extractFromUpdateSet(upd, 0)
   104  			recursiveDeps := semTable.RecursiveDeps(updExpr.Name)
   105  			assert.Equal(t, T1, recursiveDeps, query)
   106  			assert.Equal(t, T1, semTable.DirectDeps(updExpr.Name), query)
   107  			assert.Equal(t, 1, recursiveDeps.NumberOfTables(), "number of tables is wrong")
   108  
   109  			extractedSubqs := semTable.SubqueryMap[upd]
   110  			require.Len(t, extractedSubqs, len(upd.Exprs))
   111  
   112  			for _, esubq := range extractedSubqs {
   113  				subq := esubq.Subquery
   114  				extractedSubq := semTable.SubqueryRef[subq]
   115  				assert.True(t, sqlparser.Equals.Expr(extractedSubq.Subquery, subq))
   116  			}
   117  		})
   118  	}
   119  }
   120  
   121  func TestUpdWhereSubquery(t *testing.T) {
   122  	queries := []string{
   123  		"update tabl set col = 1 where id = (select id from a)",
   124  		"update tabl set col = 1 where id IN (select id from a)",
   125  		"update tabl set col = 1 where exists (select id from a)",
   126  		"update tabl set col = 1 where 1 = (select id from a)",
   127  		"update tabl set col = 1 where exists (select id from a) and id > (select name from city) and col < (select i from a)",
   128  	}
   129  	for _, query := range queries {
   130  		t.Run(query, func(t *testing.T) {
   131  			stmt, semTable := parseAndAnalyze(t, query, "d")
   132  			upd, _ := stmt.(*sqlparser.Update)
   133  			t1 := upd.TableExprs[0].(*sqlparser.AliasedTableExpr)
   134  			ts := semTable.TableSetFor(t1)
   135  			assert.Equal(t, SingleTableSet(0), ts)
   136  
   137  			extractedSubqs := semTable.SubqueryMap[upd]
   138  			require.Len(t, extractedSubqs, len(sqlparser.SplitAndExpression(nil, upd.Where.Expr)))
   139  
   140  			for _, esubq := range extractedSubqs {
   141  				subq := esubq.Subquery
   142  				extractedSubq := semTable.SubqueryRef[subq]
   143  				assert.True(t, sqlparser.Equals.Expr(extractedSubq.Subquery, subq))
   144  			}
   145  		})
   146  	}
   147  }
   148  
   149  func TestUpdSetAndWhereSubquery(t *testing.T) {
   150  	queries := []string{
   151  		"update tabl set col = (select b from alpha) where id = (select id from a)",
   152  		"update tabl set col = (select b from alpha) where exists (select id from a)",
   153  		"update tabl set col = 1+(select b from alpha) where 1 > (select id from a)",
   154  	}
   155  	for _, query := range queries {
   156  		t.Run(query, func(t *testing.T) {
   157  			stmt, semTable := parseAndAnalyze(t, query, "d")
   158  			upd, _ := stmt.(*sqlparser.Update)
   159  			t1 := upd.TableExprs[0].(*sqlparser.AliasedTableExpr)
   160  			ts := semTable.TableSetFor(t1)
   161  			assert.Equal(t, SingleTableSet(0), ts)
   162  
   163  			extractedSubqs := semTable.SubqueryMap[upd]
   164  			require.Len(t, extractedSubqs, len(sqlparser.SplitAndExpression(nil, upd.Where.Expr))+len(upd.Exprs))
   165  
   166  			for _, esubq := range extractedSubqs {
   167  				subq := esubq.Subquery
   168  				extractedSubq := semTable.SubqueryRef[subq]
   169  				assert.True(t, sqlparser.Equals.Expr(extractedSubq.Subquery, subq))
   170  			}
   171  		})
   172  	}
   173  }
   174  
   175  func extractFromUpdateSet(in *sqlparser.Update, idx int) *sqlparser.UpdateExpr {
   176  	return in.Exprs[idx]
   177  }