github.com/pingcap/tidb/parser@v0.0.0-20231013125129-93a834a6bf8d/ast/format_test.go (about)

     1  package ast_test
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"testing"
     7  
     8  	"github.com/pingcap/tidb/parser"
     9  	"github.com/pingcap/tidb/parser/ast"
    10  	"github.com/stretchr/testify/require"
    11  )
    12  
    13  func getDefaultCharsetAndCollate() (string, string) {
    14  	return "utf8", "utf8_bin"
    15  }
    16  
    17  func TestAstFormat(t *testing.T) {
    18  	var testcases = []struct {
    19  		input  string
    20  		output string
    21  	}{
    22  		// Literals.
    23  		{`null`, `NULL`},
    24  		{`true`, `TRUE`},
    25  		{`350`, `350`},
    26  		{`001e-12`, `1e-12`}, // Float.
    27  		{`345.678`, `345.678`},
    28  		{`00.0001000`, `0.0001000`}, // Decimal.
    29  		{`null`, `NULL`},
    30  		{`"Hello, world"`, `"Hello, world"`},
    31  		{`'Hello, world'`, `"Hello, world"`},
    32  		{`'Hello, "world"'`, `"Hello, \"world\""`},
    33  		{`_utf8'你好'`, `"你好"`},
    34  		{`x'bcde'`, "x'bcde'"},
    35  		{`x''`, "x''"},
    36  		{`x'0035'`, "x'0035'"}, // Shouldn't trim leading zero.
    37  		{`b'00111111'`, `b'111111'`},
    38  		{`time'10:10:10.123'`, ast.TimeLiteral + `("10:10:10.123")`},
    39  		{`timestamp'1999-01-01 10:0:0.123'`, ast.TimestampLiteral + `("1999-01-01 10:0:0.123")`},
    40  		{`date '1700-01-01'`, ast.DateLiteral + `("1700-01-01")`},
    41  
    42  		// Expressions.
    43  		{`f between 30 and 50`, "`f` BETWEEN 30 AND 50"},
    44  		{`f not between 30 and 50`, "`f` NOT BETWEEN 30 AND 50"},
    45  		{`345 + "  hello  "`, `345 + "  hello  "`},
    46  		{`"hello world"    >=    'hello world'`, `"hello world" >= "hello world"`},
    47  		{`case 3 when 1 then false else true end`, `CASE 3 WHEN 1 THEN FALSE ELSE TRUE END`},
    48  		{`database.table.column`, "`database`.`table`.`column`"}, // ColumnNameExpr
    49  		{`3 is null`, `3 IS NULL`},
    50  		{`3 is not null`, `3 IS NOT NULL`},
    51  		{`3 is true`, `3 IS TRUE`},
    52  		{`3 is not true`, `3 IS NOT TRUE`},
    53  		{`3 is false`, `3 IS FALSE`},
    54  		{`  ( x is false  )`, "(`x` IS FALSE)"},
    55  		{`3 in ( a,b,"h",6 )`, "3 IN (`a`,`b`,\"h\",6)"},
    56  		{`3 not in ( a,b,"h",6 )`, "3 NOT IN (`a`,`b`,\"h\",6)"},
    57  		{`"abc" like '%b%'`, `"abc" LIKE "%b%"`},
    58  		{`"abc" not like '%b%'`, `"abc" NOT LIKE "%b%"`},
    59  		{`"abc" like '%b%' escape '_'`, `"abc" LIKE "%b%" ESCAPE '_'`},
    60  		{`"abc" regexp '.*bc?'`, `"abc" REGEXP ".*bc?"`},
    61  		{`"abc" not regexp '.*bc?'`, `"abc" NOT REGEXP ".*bc?"`},
    62  		{`-  4`, `-4`},
    63  		{`- ( - 4 ) `, `-(-4)`},
    64  		{`a%b`, "`a` % `b`"},
    65  		{`a%b+6`, "`a` % `b` + 6"},
    66  		{`a%(b+6)`, "`a` % (`b` + 6)"},
    67  		// Functions.
    68  		{` json_extract ( a,'$.b',"$.\"c d\"" ) `, "json_extract(`a`, \"$.b\", \"$.\\\"c d\\\"\")"},
    69  		{` length ( a )`, "length(`a`)"},
    70  		{`a -> '$.a'`, "json_extract(`a`, \"$.a\")"},
    71  		{`a.b ->> '$.a'`, "json_unquote(json_extract(`a`.`b`, \"$.a\"))"},
    72  		{`DATE_ADD('1970-01-01', interval 3 second)`, `date_add("1970-01-01", INTERVAL 3 SECOND)`},
    73  		{`TIMESTAMPDIFF(month, '2001-01-01', '2001-02-02 12:03:05.123')`, `timestampdiff(MONTH, "2001-01-01", "2001-02-02 12:03:05.123")`},
    74  		// Cast, Convert and Binary.
    75  		// There should not be spaces between 'cast' and '(' unless 'IGNORE_SPACE' mode is set.
    76  		// see: https://dev.mysql.com/doc/refman/5.7/en/function-resolution.html
    77  		{` cast( a as signed ) `, "CAST(`a` AS SIGNED)"},
    78  		{` cast( a as unsigned integer) `, "CAST(`a` AS UNSIGNED)"},
    79  		{` cast( a as char(3) binary) `, "CAST(`a` AS BINARY(3))"},
    80  		{` cast( a as decimal ) `, "CAST(`a` AS DECIMAL(10))"},
    81  		{` cast( a as decimal (3) ) `, "CAST(`a` AS DECIMAL(3))"},
    82  		{` cast( a as decimal (3,3) ) `, "CAST(`a` AS DECIMAL(3, 3))"},
    83  		{` ((case when (c0 = 0) then 0 when (c0 > 0) then (c1 / c0) end)) `, "((CASE WHEN (`c0` = 0) THEN 0 WHEN (`c0` > 0) THEN (`c1` / `c0`) END))"},
    84  		{` convert (a, signed) `, "CONVERT(`a`, SIGNED)"},
    85  		{` binary "hello"`, `BINARY "hello"`},
    86  	}
    87  	for _, tt := range testcases {
    88  		expr := fmt.Sprintf("select %s", tt.input)
    89  		charset, collation := getDefaultCharsetAndCollate()
    90  		stmts, _, err := parser.New().Parse(expr, charset, collation)
    91  		node := stmts[0].(*ast.SelectStmt).Fields.Fields[0].Expr
    92  		require.NoError(t, err)
    93  
    94  		writer := bytes.NewBufferString("")
    95  		node.Format(writer)
    96  		require.Equal(t, tt.output, writer.String())
    97  	}
    98  }