github.com/pingcap/tidb/parser@v0.0.0-20231013125129-93a834a6bf8d/ast/functions_test.go (about)

     1  // Copyright 2017 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package ast_test
    15  
    16  import (
    17  	"testing"
    18  
    19  	"github.com/pingcap/tidb/parser"
    20  	. "github.com/pingcap/tidb/parser/ast"
    21  	"github.com/pingcap/tidb/parser/mysql"
    22  	"github.com/pingcap/tidb/parser/test_driver"
    23  	"github.com/stretchr/testify/require"
    24  )
    25  
    26  func TestFunctionsVisitorCover(t *testing.T) {
    27  	valueExpr := NewValueExpr(42, mysql.DefaultCharset, mysql.DefaultCollationName)
    28  	stmts := []Node{
    29  		&AggregateFuncExpr{Args: []ExprNode{valueExpr}},
    30  		&FuncCallExpr{Args: []ExprNode{valueExpr}},
    31  		&FuncCastExpr{Expr: valueExpr},
    32  		&WindowFuncExpr{Spec: WindowSpec{}},
    33  	}
    34  
    35  	for _, stmt := range stmts {
    36  		stmt.Accept(visitor{})
    37  		stmt.Accept(visitor1{})
    38  	}
    39  }
    40  
    41  func TestFuncCallExprRestore(t *testing.T) {
    42  	testCases := []NodeRestoreTestCase{
    43  		{"JSON_ARRAYAGG(attribute)", "JSON_ARRAYAGG(`attribute`)"},
    44  		{"JSON_OBJECTAGG(attribute, value)", "JSON_OBJECTAGG(`attribute`, `value`)"},
    45  		{"ABS(-1024)", "ABS(-1024)"},
    46  		{"ACOS(3.14)", "ACOS(3.14)"},
    47  		{"CONV('a',16,2)", "CONV(_UTF8MB4'a', 16, 2)"},
    48  		{"COS(PI())", "COS(PI())"},
    49  		{"RAND()", "RAND()"},
    50  		{"ADDDATE('2000-01-01', 1)", "ADDDATE(_UTF8MB4'2000-01-01', INTERVAL 1 DAY)"},
    51  		{"DATE_ADD('2000-01-01', INTERVAL 1 DAY)", "DATE_ADD(_UTF8MB4'2000-01-01', INTERVAL 1 DAY)"},
    52  		{"DATE_ADD('2000-01-01', INTERVAL '1 1:12:23.100000' DAY_MICROSECOND)", "DATE_ADD(_UTF8MB4'2000-01-01', INTERVAL _UTF8MB4'1 1:12:23.100000' DAY_MICROSECOND)"},
    53  		{"EXTRACT(DAY FROM '2000-01-01')", "EXTRACT(DAY FROM _UTF8MB4'2000-01-01')"},
    54  		{"extract(day from '1999-01-01')", "EXTRACT(DAY FROM _UTF8MB4'1999-01-01')"},
    55  		{"GET_FORMAT(DATE, 'EUR')", "GET_FORMAT(DATE, _UTF8MB4'EUR')"},
    56  		{"POSITION('a' IN 'abc')", "POSITION(_UTF8MB4'a' IN _UTF8MB4'abc')"},
    57  		{"TRIM('  bar   ')", "TRIM(_UTF8MB4'  bar   ')"},
    58  		{"TRIM('a' FROM '  bar   ')", "TRIM(_UTF8MB4'a' FROM _UTF8MB4'  bar   ')"},
    59  		{"TRIM(LEADING FROM '  bar   ')", "TRIM(LEADING _UTF8MB4' ' FROM _UTF8MB4'  bar   ')"},
    60  		{"TRIM(BOTH FROM '  bar   ')", "TRIM(BOTH _UTF8MB4' ' FROM _UTF8MB4'  bar   ')"},
    61  		{"TRIM(TRAILING FROM '  bar   ')", "TRIM(TRAILING _UTF8MB4' ' FROM _UTF8MB4'  bar   ')"},
    62  		{"TRIM(LEADING 'x' FROM 'xxxyxxx')", "TRIM(LEADING _UTF8MB4'x' FROM _UTF8MB4'xxxyxxx')"},
    63  		{"TRIM(BOTH 'x' FROM 'xxxyxxx')", "TRIM(BOTH _UTF8MB4'x' FROM _UTF8MB4'xxxyxxx')"},
    64  		{"TRIM(TRAILING 'x' FROM 'xxxyxxx')", "TRIM(TRAILING _UTF8MB4'x' FROM _UTF8MB4'xxxyxxx')"},
    65  		{"TRIM(BOTH col1 FROM col2)", "TRIM(BOTH `col1` FROM `col2`)"},
    66  		{"DATE_ADD('2008-01-02', INTERVAL INTERVAL(1, 0, 1) DAY)", "DATE_ADD(_UTF8MB4'2008-01-02', INTERVAL INTERVAL(1, 0, 1) DAY)"},
    67  		{"BENCHMARK(1000000, AES_ENCRYPT('text', UNHEX('F3229A0B371ED2D9441B830D21A390C3')))", "BENCHMARK(1000000, AES_ENCRYPT(_UTF8MB4'text', UNHEX(_UTF8MB4'F3229A0B371ED2D9441B830D21A390C3')))"},
    68  		{"SUBSTRING('Quadratically', 5)", "SUBSTRING(_UTF8MB4'Quadratically', 5)"},
    69  		{"SUBSTRING('Quadratically' FROM 5)", "SUBSTRING(_UTF8MB4'Quadratically', 5)"},
    70  		{"SUBSTRING('Quadratically', 5, 6)", "SUBSTRING(_UTF8MB4'Quadratically', 5, 6)"},
    71  		{"SUBSTRING('Quadratically' FROM 5 FOR 6)", "SUBSTRING(_UTF8MB4'Quadratically', 5, 6)"},
    72  		{"MASTER_POS_WAIT(@log_name, @log_pos, @timeout, @channel_name)", "MASTER_POS_WAIT(@`log_name`, @`log_pos`, @`timeout`, @`channel_name`)"},
    73  		{"JSON_TYPE('[123]')", "JSON_TYPE(_UTF8MB4'[123]')"},
    74  		{"bit_and(all c1)", "BIT_AND(`c1`)"},
    75  		{"nextval(seq)", "NEXTVAL(`seq`)"},
    76  		{"nextval(test.seq)", "NEXTVAL(`test`.`seq`)"},
    77  		{"lastval(seq)", "LASTVAL(`seq`)"},
    78  		{"lastval(test.seq)", "LASTVAL(`test`.`seq`)"},
    79  		{"setval(seq, 100)", "SETVAL(`seq`, 100)"},
    80  		{"setval(test.seq, 100)", "SETVAL(`test`.`seq`, 100)"},
    81  		{"next value for seq", "NEXTVAL(`seq`)"},
    82  		{"next value for test.seq", "NEXTVAL(`test`.`seq`)"},
    83  		{"next value for sequence", "NEXTVAL(`sequence`)"},
    84  		{"NeXt vAluE for seQuEncE2", "NEXTVAL(`seQuEncE2`)"},
    85  		{"NeXt vAluE for test.seQuEncE2", "NEXTVAL(`test`.`seQuEncE2`)"},
    86  		{"weight_string(a)", "WEIGHT_STRING(`a`)"},
    87  		{"Weight_stRing(test.a)", "WEIGHT_STRING(`test`.`a`)"},
    88  		{"weight_string('a')", "WEIGHT_STRING(_UTF8MB4'a')"},
    89  		// Expressions with collations of different charsets will lead to an error in MySQL, but the error check should be done in TiDB, so it's valid here.
    90  		{"weight_string('a' collate utf8_general_ci collate utf8mb4_general_ci)", "WEIGHT_STRING(_UTF8MB4'a' COLLATE utf8_general_ci COLLATE utf8mb4_general_ci)"},
    91  		{"weight_string(_utf8 'a' collate utf8_general_ci)", "WEIGHT_STRING(_UTF8'a' COLLATE utf8_general_ci)"},
    92  		{"weight_string(_utf8 'a')", "WEIGHT_STRING(_UTF8'a')"},
    93  		{"weight_string(a as char(5))", "WEIGHT_STRING(`a` AS CHAR(5))"},
    94  		{"weight_string(a as character(5))", "WEIGHT_STRING(`a` AS CHAR(5))"},
    95  		{"weight_string(a as binary(5))", "WEIGHT_STRING(`a` AS BINARY(5))"},
    96  		{"hex(weight_string('abc' as binary(5)))", "HEX(WEIGHT_STRING(_UTF8MB4'abc' AS BINARY(5)))"},
    97  		{"soundex(attr)", "SOUNDEX(`attr`)"},
    98  		{"soundex('string')", "SOUNDEX(_UTF8MB4'string')"},
    99  	}
   100  	extractNodeFunc := func(node Node) Node {
   101  		return node.(*SelectStmt).Fields.Fields[0].Expr
   102  	}
   103  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   104  }
   105  
   106  func TestFuncCastExprRestore(t *testing.T) {
   107  	testCases := []NodeRestoreTestCase{
   108  		{"CONVERT('Müller' USING UtF8)", "CONVERT(_UTF8MB4'Müller' USING 'utf8')"},
   109  		{"CONVERT('Müller' USING UtF8Mb4)", "CONVERT(_UTF8MB4'Müller' USING 'utf8mb4')"},
   110  		{"CONVERT('Müller', CHAR(32) CHARACTER SET UtF8)", "CONVERT(_UTF8MB4'Müller', CHAR(32) CHARSET UTF8)"},
   111  		{"CAST('test' AS CHAR CHARACTER SET UtF8)", "CAST(_UTF8MB4'test' AS CHAR CHARSET UTF8)"},
   112  		{"BINARY 'New York'", "BINARY _UTF8MB4'New York'"},
   113  	}
   114  	extractNodeFunc := func(node Node) Node {
   115  		return node.(*SelectStmt).Fields.Fields[0].Expr
   116  	}
   117  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   118  }
   119  
   120  func TestAggregateFuncExprRestore(t *testing.T) {
   121  	testCases := []NodeRestoreTestCase{
   122  		{"AVG(test_score)", "AVG(`test_score`)"},
   123  		{"AVG(distinct test_score)", "AVG(DISTINCT `test_score`)"},
   124  		{"BIT_AND(test_score)", "BIT_AND(`test_score`)"},
   125  		{"BIT_OR(test_score)", "BIT_OR(`test_score`)"},
   126  		{"BIT_XOR(test_score)", "BIT_XOR(`test_score`)"},
   127  		{"COUNT(test_score)", "COUNT(`test_score`)"},
   128  		{"COUNT(*)", "COUNT(1)"},
   129  		{"COUNT(DISTINCT scores, results)", "COUNT(DISTINCT `scores`, `results`)"},
   130  		{"MIN(test_score)", "MIN(`test_score`)"},
   131  		{"MIN(DISTINCT test_score)", "MIN(DISTINCT `test_score`)"},
   132  		{"MAX(test_score)", "MAX(`test_score`)"},
   133  		{"MAX(DISTINCT test_score)", "MAX(DISTINCT `test_score`)"},
   134  		{"STD(test_score)", "STDDEV_POP(`test_score`)"},
   135  		{"STDDEV(test_score)", "STDDEV_POP(`test_score`)"},
   136  		{"STDDEV_POP(test_score)", "STDDEV_POP(`test_score`)"},
   137  		{"STDDEV_SAMP(test_score)", "STDDEV_SAMP(`test_score`)"},
   138  		{"SUM(test_score)", "SUM(`test_score`)"},
   139  		{"SUM(DISTINCT test_score)", "SUM(DISTINCT `test_score`)"},
   140  		{"VAR_POP(test_score)", "VAR_POP(`test_score`)"},
   141  		{"VAR_SAMP(test_score)", "VAR_SAMP(`test_score`)"},
   142  		{"VARIANCE(test_score)", "VAR_POP(`test_score`)"},
   143  		{"JSON_OBJECTAGG(test_score, results)", "JSON_OBJECTAGG(`test_score`, `results`)"},
   144  		{"GROUP_CONCAT(a)", "GROUP_CONCAT(`a` SEPARATOR ',')"},
   145  		{"GROUP_CONCAT(a separator '--')", "GROUP_CONCAT(`a` SEPARATOR '--')"},
   146  		{"GROUP_CONCAT(a order by b desc, c)", "GROUP_CONCAT(`a` ORDER BY `b` DESC,`c` SEPARATOR ',')"},
   147  		{"GROUP_CONCAT(a order by b desc, c separator '--')", "GROUP_CONCAT(`a` ORDER BY `b` DESC,`c` SEPARATOR '--')"},
   148  	}
   149  	extractNodeFunc := func(node Node) Node {
   150  		return node.(*SelectStmt).Fields.Fields[0].Expr
   151  	}
   152  	runNodeRestoreTest(t, testCases, "select %s", extractNodeFunc)
   153  }
   154  
   155  func TestConvert(t *testing.T) {
   156  	// Test case for CONVERT(expr USING transcoding_name).
   157  	cases := []struct {
   158  		SQL          string
   159  		CharsetName  string
   160  		ErrorMessage string
   161  	}{
   162  		{`SELECT CONVERT("abc" USING "latin1")`, "latin1", ""},
   163  		{`SELECT CONVERT("abc" USING laTiN1)`, "latin1", ""},
   164  		{`SELECT CONVERT("abc" USING "binary")`, "binary", ""},
   165  		{`SELECT CONVERT("abc" USING biNaRy)`, "binary", ""},
   166  		{`SELECT CONVERT(a USING a)`, "", `[parser:1115]Unknown character set: 'a'`}, // TiDB issue #4436.
   167  		{`SELECT CONVERT("abc" USING CONCAT("utf", "8"))`, "", `[parser:1115]Unknown character set: 'CONCAT'`},
   168  	}
   169  	for _, testCase := range cases {
   170  		stmt, err := parser.New().ParseOneStmt(testCase.SQL, "", "")
   171  		if testCase.ErrorMessage != "" {
   172  			require.EqualError(t, err, testCase.ErrorMessage)
   173  			continue
   174  		}
   175  		require.NoError(t, err)
   176  
   177  		st := stmt.(*SelectStmt)
   178  		expr := st.Fields.Fields[0].Expr.(*FuncCallExpr)
   179  		charsetArg := expr.Args[1].(*test_driver.ValueExpr)
   180  		require.Equal(t, testCase.CharsetName, charsetArg.GetString())
   181  	}
   182  }
   183  
   184  func TestChar(t *testing.T) {
   185  	// Test case for CHAR(N USING charset_name)
   186  	cases := []struct {
   187  		SQL          string
   188  		CharsetName  string
   189  		ErrorMessage string
   190  	}{
   191  		{`SELECT CHAR("abc" USING "latin1")`, "latin1", ""},
   192  		{`SELECT CHAR("abc" USING laTiN1)`, "latin1", ""},
   193  		{`SELECT CHAR("abc" USING "binary")`, "binary", ""},
   194  		{`SELECT CHAR("abc" USING binary)`, "binary", ""},
   195  		{`SELECT CHAR(a USING a)`, "", `[parser:1115]Unknown character set: 'a'`},
   196  		{`SELECT CHAR("abc" USING CONCAT("utf", "8"))`, "", `[parser:1115]Unknown character set: 'CONCAT'`},
   197  	}
   198  	for _, testCase := range cases {
   199  		stmt, err := parser.New().ParseOneStmt(testCase.SQL, "", "")
   200  		if testCase.ErrorMessage != "" {
   201  			require.EqualError(t, err, testCase.ErrorMessage)
   202  			continue
   203  		}
   204  		require.NoError(t, err)
   205  
   206  		st := stmt.(*SelectStmt)
   207  		expr := st.Fields.Fields[0].Expr.(*FuncCallExpr)
   208  		charsetArg := expr.Args[1].(*test_driver.ValueExpr)
   209  		require.Equal(t, testCase.CharsetName, charsetArg.GetString())
   210  	}
   211  }
   212  
   213  func TestWindowFuncExprRestore(t *testing.T) {
   214  	testCases := []NodeRestoreTestCase{
   215  		{"RANK() OVER w", "RANK() OVER `w`"},
   216  		{"RANK() OVER (PARTITION BY a)", "RANK() OVER (PARTITION BY `a`)"},
   217  		{"MAX(DISTINCT a) OVER (PARTITION BY a)", "MAX(DISTINCT `a`) OVER (PARTITION BY `a`)"},
   218  		{"MAX(DISTINCTROW a) OVER (PARTITION BY a)", "MAX(DISTINCT `a`) OVER (PARTITION BY `a`)"},
   219  		{"MAX(DISTINCT ALL a) OVER (PARTITION BY a)", "MAX(DISTINCT `a`) OVER (PARTITION BY `a`)"},
   220  		{"MAX(ALL a) OVER (PARTITION BY a)", "MAX(`a`) OVER (PARTITION BY `a`)"},
   221  		{"FIRST_VALUE(val) IGNORE NULLS OVER (w)", "FIRST_VALUE(`val`) IGNORE NULLS OVER (`w`)"},
   222  		{"FIRST_VALUE(val) RESPECT NULLS OVER w", "FIRST_VALUE(`val`) OVER `w`"},
   223  		{"NTH_VALUE(val, 233) FROM LAST IGNORE NULLS OVER w", "NTH_VALUE(`val`, 233) FROM LAST IGNORE NULLS OVER `w`"},
   224  		{"NTH_VALUE(val, 233) FROM FIRST IGNORE NULLS OVER (w)", "NTH_VALUE(`val`, 233) IGNORE NULLS OVER (`w`)"},
   225  	}
   226  	extractNodeFunc := func(node Node) Node {
   227  		return node.(*SelectStmt).Fields.Fields[0].Expr
   228  	}
   229  	runNodeRestoreTest(t, testCases, "select %s from t", extractNodeFunc)
   230  }
   231  
   232  func TestGenericFuncRestore(t *testing.T) {
   233  	testCases := []NodeRestoreTestCase{
   234  		{"s.a()", "`s`.`a`()"},
   235  		{"`s`.`a`()", "`s`.`a`()"},
   236  		{"now()", "NOW()"},
   237  		{"`s`.`now`()", "`s`.`now`()"},
   238  		// FIXME: expectSQL should be `generic_func()`.
   239  		{"generic_func()", "GENERIC_FUNC()"},
   240  		{"`ident.1`.`ident.2`()", "`ident.1`.`ident.2`()"},
   241  	}
   242  	extractNodeFunc := func(node Node) Node {
   243  		return node.(*SelectStmt).Fields.Fields[0].Expr
   244  	}
   245  	runNodeRestoreTest(t, testCases, "select %s from t", extractNodeFunc)
   246  }