github.com/pingcap/tidb/parser@v0.0.0-20231013125129-93a834a6bf8d/ast/expressions_test.go (about)

     1  // Copyright 2017 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package ast_test
    15  
    16  import (
    17  	"testing"
    18  
    19  	. "github.com/pingcap/tidb/parser/ast"
    20  	"github.com/pingcap/tidb/parser/format"
    21  	"github.com/pingcap/tidb/parser/mysql"
    22  	"github.com/stretchr/testify/require"
    23  )
    24  
    25  type checkVisitor struct{}
    26  
    27  func (v checkVisitor) Enter(in Node) (Node, bool) {
    28  	if e, ok := in.(*checkExpr); ok {
    29  		e.enterCnt++
    30  		return in, true
    31  	}
    32  	return in, false
    33  }
    34  
    35  func (v checkVisitor) Leave(in Node) (Node, bool) {
    36  	if e, ok := in.(*checkExpr); ok {
    37  		e.leaveCnt++
    38  	}
    39  	return in, true
    40  }
    41  
    42  type checkExpr struct {
    43  	ValueExpr
    44  
    45  	enterCnt int
    46  	leaveCnt int
    47  }
    48  
    49  func (n *checkExpr) Accept(v Visitor) (Node, bool) {
    50  	newNode, skipChildren := v.Enter(n)
    51  	if skipChildren {
    52  		return v.Leave(newNode)
    53  	}
    54  	n = newNode.(*checkExpr)
    55  	return v.Leave(n)
    56  }
    57  
    58  func (n *checkExpr) reset() {
    59  	n.enterCnt = 0
    60  	n.leaveCnt = 0
    61  }
    62  
    63  func TestExpresionsVisitorCover(t *testing.T) {
    64  	ce := &checkExpr{}
    65  	stmts :=
    66  		[]struct {
    67  			node             Node
    68  			expectedEnterCnt int
    69  			expectedLeaveCnt int
    70  		}{
    71  			{&BetweenExpr{Expr: ce, Left: ce, Right: ce}, 3, 3},
    72  			{&BinaryOperationExpr{L: ce, R: ce}, 2, 2},
    73  			{&CaseExpr{Value: ce, WhenClauses: []*WhenClause{{Expr: ce, Result: ce},
    74  				{Expr: ce, Result: ce}}, ElseClause: ce}, 6, 6},
    75  			{&ColumnNameExpr{Name: &ColumnName{}}, 0, 0},
    76  			{&CompareSubqueryExpr{L: ce, R: ce}, 2, 2},
    77  			{&DefaultExpr{Name: &ColumnName{}}, 0, 0},
    78  			{&ExistsSubqueryExpr{Sel: ce}, 1, 1},
    79  			{&IsNullExpr{Expr: ce}, 1, 1},
    80  			{&IsTruthExpr{Expr: ce}, 1, 1},
    81  			{NewParamMarkerExpr(0), 0, 0},
    82  			{&ParenthesesExpr{Expr: ce}, 1, 1},
    83  			{&PatternInExpr{Expr: ce, List: []ExprNode{ce, ce, ce}, Sel: ce}, 5, 5},
    84  			{&PatternLikeOrIlikeExpr{Expr: ce, Pattern: ce}, 2, 2},
    85  			{&PatternRegexpExpr{Expr: ce, Pattern: ce}, 2, 2},
    86  			{&PositionExpr{}, 0, 0},
    87  			{&RowExpr{Values: []ExprNode{ce, ce}}, 2, 2},
    88  			{&UnaryOperationExpr{V: ce}, 1, 1},
    89  			{NewValueExpr(0, mysql.DefaultCharset, mysql.DefaultCollationName), 0, 0},
    90  			{&ValuesExpr{Column: &ColumnNameExpr{Name: &ColumnName{}}}, 0, 0},
    91  			{&VariableExpr{Value: ce}, 1, 1},
    92  		}
    93  
    94  	for _, v := range stmts {
    95  		ce.reset()
    96  		v.node.Accept(checkVisitor{})
    97  		require.Equal(t, v.expectedEnterCnt, ce.enterCnt)
    98  		require.Equal(t, v.expectedLeaveCnt, ce.leaveCnt)
    99  		v.node.Accept(visitor1{})
   100  	}
   101  }
   102  
   103  func TestUnaryOperationExprRestore(t *testing.T) {
   104  	testCases := []NodeRestoreTestCase{
   105  		{"++1", "++1"},
   106  		{"--1", "--1"},
   107  		{"-+1", "-+1"},
   108  		{"-1", "-1"},
   109  		{"not true", "NOT TRUE"},
   110  		{"~3", "~3"},
   111  		{"!true", "!TRUE"},
   112  	}
   113  	extractNodeFunc := func(node Node) Node {
   114  		return node.(*SelectStmt).Fields.Fields[0].Expr
   115  	}
   116  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   117  }
   118  
   119  func TestColumnNameExprRestore(t *testing.T) {
   120  	testCases := []NodeRestoreTestCase{
   121  		{"abc", "`abc`"},
   122  		{"`abc`", "`abc`"},
   123  		{"`ab``c`", "`ab``c`"},
   124  		{"sabc.tABC", "`sabc`.`tABC`"},
   125  		{"dabc.sabc.tabc", "`dabc`.`sabc`.`tabc`"},
   126  		{"dabc.`sabc`.tabc", "`dabc`.`sabc`.`tabc`"},
   127  		{"`dABC`.`sabc`.tabc", "`dABC`.`sabc`.`tabc`"},
   128  	}
   129  	extractNodeFunc := func(node Node) Node {
   130  		return node.(*SelectStmt).Fields.Fields[0].Expr
   131  	}
   132  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   133  }
   134  
   135  func TestIsNullExprRestore(t *testing.T) {
   136  	testCases := []NodeRestoreTestCase{
   137  		{"a is null", "`a` IS NULL"},
   138  		{"a is not null", "`a` IS NOT NULL"},
   139  	}
   140  	extractNodeFunc := func(node Node) Node {
   141  		return node.(*SelectStmt).Fields.Fields[0].Expr
   142  	}
   143  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   144  }
   145  
   146  func TestIsTruthRestore(t *testing.T) {
   147  	testCases := []NodeRestoreTestCase{
   148  		{"a is true", "`a` IS TRUE"},
   149  		{"a is not true", "`a` IS NOT TRUE"},
   150  		{"a is FALSE", "`a` IS FALSE"},
   151  		{"a is not false", "`a` IS NOT FALSE"},
   152  	}
   153  	extractNodeFunc := func(node Node) Node {
   154  		return node.(*SelectStmt).Fields.Fields[0].Expr
   155  	}
   156  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   157  }
   158  
   159  func TestBetweenExprRestore(t *testing.T) {
   160  	testCases := []NodeRestoreTestCase{
   161  		{"b between 1 and 2", "`b` BETWEEN 1 AND 2"},
   162  		{"b not between 1 and 2", "`b` NOT BETWEEN 1 AND 2"},
   163  		{"b between a and b", "`b` BETWEEN `a` AND `b`"},
   164  		{"b between '' and 'b'", "`b` BETWEEN _UTF8MB4'' AND _UTF8MB4'b'"},
   165  		{"b between '2018-11-01' and '2018-11-02'", "`b` BETWEEN _UTF8MB4'2018-11-01' AND _UTF8MB4'2018-11-02'"},
   166  	}
   167  	extractNodeFunc := func(node Node) Node {
   168  		return node.(*SelectStmt).Fields.Fields[0].Expr
   169  	}
   170  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   171  }
   172  
   173  func TestCaseExpr(t *testing.T) {
   174  	testCases := []NodeRestoreTestCase{
   175  		{"case when 1 then 2 end", "CASE WHEN 1 THEN 2 END"},
   176  		{"case when 1 then 'a' when 2 then 'b' end", "CASE WHEN 1 THEN _UTF8MB4'a' WHEN 2 THEN _UTF8MB4'b' END"},
   177  		{"case when 1 then 'a' when 2 then 'b' else 'c' end", "CASE WHEN 1 THEN _UTF8MB4'a' WHEN 2 THEN _UTF8MB4'b' ELSE _UTF8MB4'c' END"},
   178  		{"case when 'a'!=1 then true else false end", "CASE WHEN _UTF8MB4'a'!=1 THEN TRUE ELSE FALSE END"},
   179  		{"case a when 'a' then true else false end", "CASE `a` WHEN _UTF8MB4'a' THEN TRUE ELSE FALSE END"},
   180  	}
   181  	extractNodeFunc := func(node Node) Node {
   182  		return node.(*SelectStmt).Fields.Fields[0].Expr
   183  	}
   184  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   185  }
   186  
   187  func TestBinaryOperationExpr(t *testing.T) {
   188  	testCases := []NodeRestoreTestCase{
   189  		{"'a'!=1", "_UTF8MB4'a'!=1"},
   190  		{"a!=1", "`a`!=1"},
   191  		{"3<5", "3<5"},
   192  		{"10>5", "10>5"},
   193  		{"3+5", "3+5"},
   194  		{"3-5", "3-5"},
   195  		{"a<>5", "`a`!=5"},
   196  		{"a=1", "`a`=1"},
   197  		{"a mod 2", "`a`%2"},
   198  		{"a div 2", "`a` DIV 2"},
   199  		{"true and true", "TRUE AND TRUE"},
   200  		{"false or false", "FALSE OR FALSE"},
   201  		{"true xor false", "TRUE XOR FALSE"},
   202  		{"3 & 4", "3&4"},
   203  		{"5 | 6", "5|6"},
   204  		{"7 ^ 8", "7^8"},
   205  		{"9 << 10", "9<<10"},
   206  		{"11 >> 12", "11>>12"},
   207  	}
   208  	extractNodeFunc := func(node Node) Node {
   209  		return node.(*SelectStmt).Fields.Fields[0].Expr
   210  	}
   211  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   212  }
   213  
   214  func TestBinaryOperationExprWithFlags(t *testing.T) {
   215  	testCases := []NodeRestoreTestCase{
   216  		{"'a'!=1", "_UTF8MB4'a' != 1"},
   217  		{"a!=1", "`a` != 1"},
   218  		{"3<5", "3 < 5"},
   219  		{"10>5", "10 > 5"},
   220  		{"3+5", "3 + 5"},
   221  		{"3-5", "3 - 5"},
   222  		{"a<>5", "`a` != 5"},
   223  		{"a=1", "`a` = 1"},
   224  	}
   225  	extractNodeFunc := func(node Node) Node {
   226  		return node.(*SelectStmt).Fields.Fields[0].Expr
   227  	}
   228  	flags := format.DefaultRestoreFlags | format.RestoreSpacesAroundBinaryOperation
   229  	runNodeRestoreTestWithFlags(t, testCases, "select %s", extractNodeFunc, flags)
   230  }
   231  
   232  func TestParenthesesExpr(t *testing.T) {
   233  	testCases := []NodeRestoreTestCase{
   234  		{"(1+2)*3", "(1+2)*3"},
   235  		{"1+2*3", "1+2*3"},
   236  	}
   237  	extractNodeFunc := func(node Node) Node {
   238  		return node.(*SelectStmt).Fields.Fields[0].Expr
   239  	}
   240  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   241  }
   242  
   243  func TestWhenClause(t *testing.T) {
   244  	testCases := []NodeRestoreTestCase{
   245  		{"when 1 then 2", "WHEN 1 THEN 2"},
   246  		{"when 1 then 'a'", "WHEN 1 THEN _UTF8MB4'a'"},
   247  		{"when 'a'!=1 then true", "WHEN _UTF8MB4'a'!=1 THEN TRUE"},
   248  	}
   249  	extractNodeFunc := func(node Node) Node {
   250  		return node.(*SelectStmt).Fields.Fields[0].Expr.(*CaseExpr).WhenClauses[0]
   251  	}
   252  	runNodeRestoreTest(t, testCases, "select case %s end", extractNodeFunc)
   253  }
   254  
   255  func TestDefaultExpr(t *testing.T) {
   256  	testCases := []NodeRestoreTestCase{
   257  		{"default", "DEFAULT"},
   258  		{"default(i)", "DEFAULT(`i`)"},
   259  	}
   260  	extractNodeFunc := func(node Node) Node {
   261  		return node.(*InsertStmt).Lists[0][0]
   262  	}
   263  	runNodeRestoreTest(t, testCases, "insert into t values(%s)", extractNodeFunc)
   264  }
   265  
   266  func TestPatternInExprRestore(t *testing.T) {
   267  	testCases := []NodeRestoreTestCase{
   268  		{"'a' in ('b')", "_UTF8MB4'a' IN (_UTF8MB4'b')"},
   269  		{"2 in (0,3,7)", "2 IN (0,3,7)"},
   270  		{"2 not in (0,3,7)", "2 NOT IN (0,3,7)"},
   271  		{"2 in (select 2)", "2 IN (SELECT 2)"},
   272  		{"2 not in (select 2)", "2 NOT IN (SELECT 2)"},
   273  	}
   274  	extractNodeFunc := func(node Node) Node {
   275  		return node.(*SelectStmt).Fields.Fields[0].Expr
   276  	}
   277  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   278  }
   279  
   280  func TestPatternLikeExprRestore(t *testing.T) {
   281  	testCases := []NodeRestoreTestCase{
   282  		{"a like 't1'", "`a` LIKE _UTF8MB4't1'"},
   283  		{"a like 't1%'", "`a` LIKE _UTF8MB4't1%'"},
   284  		{"a like '%t1%'", "`a` LIKE _UTF8MB4'%t1%'"},
   285  		{"a like '%t1_|'", "`a` LIKE _UTF8MB4'%t1_|'"},
   286  		{"a not like 't1'", "`a` NOT LIKE _UTF8MB4't1'"},
   287  		{"a not like 't1%'", "`a` NOT LIKE _UTF8MB4't1%'"},
   288  		{"a not like '%D%v%'", "`a` NOT LIKE _UTF8MB4'%D%v%'"},
   289  		{"a not like '%t1_|'", "`a` NOT LIKE _UTF8MB4'%t1_|'"},
   290  	}
   291  	extractNodeFunc := func(node Node) Node {
   292  		return node.(*SelectStmt).Fields.Fields[0].Expr
   293  	}
   294  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   295  }
   296  
   297  func TestValuesExpr(t *testing.T) {
   298  	testCases := []NodeRestoreTestCase{
   299  		{"values(a)", "VALUES(`a`)"},
   300  		{"values(a)+values(b)", "VALUES(`a`)+VALUES(`b`)"},
   301  	}
   302  	extractNodeFunc := func(node Node) Node {
   303  		return node.(*InsertStmt).OnDuplicate[0].Expr
   304  	}
   305  	runNodeRestoreTest(t, testCases, "insert into t values (1,2,3) on duplicate key update c=%s", extractNodeFunc)
   306  }
   307  
   308  func TestPatternRegexpExprRestore(t *testing.T) {
   309  	testCases := []NodeRestoreTestCase{
   310  		{"a regexp 't1'", "`a` REGEXP _UTF8MB4't1'"},
   311  		{"a regexp '^[abc][0-9]{11}|ok$'", "`a` REGEXP _UTF8MB4'^[abc][0-9]{11}|ok$'"},
   312  		{"a rlike 't1'", "`a` REGEXP _UTF8MB4't1'"},
   313  		{"a rlike '^[abc][0-9]{11}|ok$'", "`a` REGEXP _UTF8MB4'^[abc][0-9]{11}|ok$'"},
   314  		{"a not regexp 't1'", "`a` NOT REGEXP _UTF8MB4't1'"},
   315  		{"a not regexp '^[abc][0-9]{11}|ok$'", "`a` NOT REGEXP _UTF8MB4'^[abc][0-9]{11}|ok$'"},
   316  		{"a not rlike 't1'", "`a` NOT REGEXP _UTF8MB4't1'"},
   317  		{"a not rlike '^[abc][0-9]{11}|ok$'", "`a` NOT REGEXP _UTF8MB4'^[abc][0-9]{11}|ok$'"},
   318  	}
   319  	extractNodeFunc := func(node Node) Node {
   320  		return node.(*SelectStmt).Fields.Fields[0].Expr
   321  	}
   322  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   323  }
   324  
   325  func TestRowExprRestore(t *testing.T) {
   326  	testCases := []NodeRestoreTestCase{
   327  		{"(1,2)", "ROW(1,2)"},
   328  		{"(col1,col2)", "ROW(`col1`,`col2`)"},
   329  		{"row(1,2)", "ROW(1,2)"},
   330  		{"row(col1,col2)", "ROW(`col1`,`col2`)"},
   331  	}
   332  	extractNodeFunc := func(node Node) Node {
   333  		return node.(*SelectStmt).Where.(*BinaryOperationExpr).L
   334  	}
   335  	runNodeRestoreTest(t, testCases, "select 1 from t1 where %s = row(1,2)", extractNodeFunc)
   336  }
   337  
   338  func TestMaxValueExprRestore(t *testing.T) {
   339  	testCases := []NodeRestoreTestCase{
   340  		{"maxvalue", "MAXVALUE"},
   341  	}
   342  	extractNodeFunc := func(node Node) Node {
   343  		return node.(*AlterTableStmt).Specs[0].PartDefinitions[0].Clause.(*PartitionDefinitionClauseLessThan).Exprs[0]
   344  	}
   345  	runNodeRestoreTest(t, testCases, "alter table posts add partition ( partition p1 values less than %s)", extractNodeFunc)
   346  }
   347  
   348  func TestPositionExprRestore(t *testing.T) {
   349  	testCases := []NodeRestoreTestCase{
   350  		{"1", "1"},
   351  	}
   352  	extractNodeFunc := func(node Node) Node {
   353  		return node.(*SelectStmt).OrderBy.Items[0]
   354  	}
   355  	runNodeRestoreTest(t, testCases, "select * from t order by %s", extractNodeFunc)
   356  }
   357  
   358  func TestExistsSubqueryExprRestore(t *testing.T) {
   359  	testCases := []NodeRestoreTestCase{
   360  		{"EXISTS (SELECT 2)", "EXISTS (SELECT 2)"},
   361  		{"NOT EXISTS (SELECT 2)", "NOT EXISTS (SELECT 2)"},
   362  		{"NOT NOT EXISTS (SELECT 2)", "EXISTS (SELECT 2)"},
   363  		{"NOT NOT NOT EXISTS (SELECT 2)", "NOT EXISTS (SELECT 2)"},
   364  	}
   365  	extractNodeFunc := func(node Node) Node {
   366  		return node.(*SelectStmt).Where
   367  	}
   368  	runNodeRestoreTest(t, testCases, "select 1 from t1 where %s", extractNodeFunc)
   369  }
   370  
   371  func TestVariableExpr(t *testing.T) {
   372  	testCases := []NodeRestoreTestCase{
   373  		{"@a>1", "@`a`>1"},
   374  		{"@`aB`+1", "@`aB`+1"},
   375  		{"@'a':=1", "@`a`:=1"},
   376  		{"@`a``b`=4", "@`a``b`=4"},
   377  		{`@"aBC">1`, "@`aBC`>1"},
   378  		{"@`a`+1", "@`a`+1"},
   379  		{"@``", "@``"},
   380  		{"@", "@``"},
   381  		{"@@``", "@@``"},
   382  		{"@@var", "@@`var`"},
   383  		{"@@global.b='foo'", "@@GLOBAL.`b`=_UTF8MB4'foo'"},
   384  		{"@@session.'C'", "@@SESSION.`c`"},
   385  		{`@@local."aBc"`, "@@SESSION.`abc`"},
   386  	}
   387  	extractNodeFunc := func(node Node) Node {
   388  		return node.(*SelectStmt).Fields.Fields[0].Expr
   389  	}
   390  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   391  }
   392  
   393  func TestMatchAgainstExpr(t *testing.T) {
   394  	testCases := []NodeRestoreTestCase{
   395  		{`MATCH(content, title) AGAINST ('search for')`, "MATCH (`content`,`title`) AGAINST (_UTF8MB4'search for')"},
   396  		{`MATCH(content) AGAINST ('search for' IN BOOLEAN MODE)`, "MATCH (`content`) AGAINST (_UTF8MB4'search for' IN BOOLEAN MODE)"},
   397  		{`MATCH(content, title) AGAINST ('search for' WITH QUERY EXPANSION)`, "MATCH (`content`,`title`) AGAINST (_UTF8MB4'search for' WITH QUERY EXPANSION)"},
   398  		{`MATCH(content) AGAINST ('search for' IN NATURAL LANGUAGE MODE WITH QUERY EXPANSION)`, "MATCH (`content`) AGAINST (_UTF8MB4'search for' WITH QUERY EXPANSION)"},
   399  		{`MATCH(content) AGAINST ('search') AND id = 1`, "MATCH (`content`) AGAINST (_UTF8MB4'search') AND `id`=1"},
   400  		{`MATCH(content) AGAINST ('search') OR id = 1`, "MATCH (`content`) AGAINST (_UTF8MB4'search') OR `id`=1"},
   401  		{`MATCH(content) AGAINST (X'40404040' | X'01020304') OR id = 1`, "MATCH (`content`) AGAINST (x'40404040'|x'01020304') OR `id`=1"},
   402  	}
   403  	extractNodeFunc := func(node Node) Node {
   404  		return node.(*SelectStmt).Where
   405  	}
   406  	runNodeRestoreTest(t, testCases, "SELECT * FROM t WHERE %s", extractNodeFunc)
   407  }