github.com/XiaoMi/Gaea@v1.2.5/parser/ast/flag_test.go (about)

     1  // Copyright 2016 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package ast_test
    15  
    16  import (
    17  	"testing"
    18  
    19  	. "github.com/pingcap/check"
    20  
    21  	"github.com/XiaoMi/Gaea/parser"
    22  	"github.com/XiaoMi/Gaea/parser/ast"
    23  )
    24  
    25  func TestT(t *testing.T) {
    26  	CustomVerboseFlag = true
    27  	TestingT(t)
    28  }
    29  
    30  var _ = Suite(&testFlagSuite{})
    31  
    32  type testFlagSuite struct {
    33  	*parser.Parser
    34  }
    35  
    36  func (ts *testFlagSuite) SetUpSuite(c *C) {
    37  	ts.Parser = parser.New()
    38  }
    39  
    40  func (ts *testFlagSuite) TestHasAggFlag(c *C) {
    41  	expr := &ast.BetweenExpr{}
    42  	flagTests := []struct {
    43  		flag   uint64
    44  		hasAgg bool
    45  	}{
    46  		{ast.FlagHasAggregateFunc, true},
    47  		{ast.FlagHasAggregateFunc | ast.FlagHasVariable, true},
    48  		{ast.FlagHasVariable, false},
    49  	}
    50  	for _, tt := range flagTests {
    51  		expr.SetFlag(tt.flag)
    52  		c.Assert(ast.HasAggFlag(expr), Equals, tt.hasAgg)
    53  	}
    54  }
    55  
    56  func (ts *testFlagSuite) TestFlag(c *C) {
    57  	flagTests := []struct {
    58  		expr string
    59  		flag uint64
    60  	}{
    61  		{
    62  			"1 between 0 and 2",
    63  			ast.FlagConstant,
    64  		},
    65  		{
    66  			"case 1 when 1 then 1 else 0 end",
    67  			ast.FlagConstant,
    68  		},
    69  		{
    70  			"case 1 when 1 then 1 else 0 end",
    71  			ast.FlagConstant,
    72  		},
    73  		{
    74  			"case 1 when a > 1 then 1 else 0 end",
    75  			ast.FlagConstant | ast.FlagHasReference,
    76  		},
    77  		{
    78  			"1 = ANY (select 1) OR exists (select 1)",
    79  			ast.FlagHasSubquery,
    80  		},
    81  		{
    82  			"1 in (1) or 1 is true or null is null or 'abc' like 'abc' or 'abc' rlike 'abc'",
    83  			ast.FlagConstant,
    84  		},
    85  		{
    86  			"row (1, 1) = row (1, 1)",
    87  			ast.FlagConstant,
    88  		},
    89  		{
    90  			"(1 + a) > ?",
    91  			ast.FlagHasReference | ast.FlagHasParamMarker,
    92  		},
    93  		{
    94  			"trim('abc ')",
    95  			ast.FlagHasFunc,
    96  		},
    97  		{
    98  			"now() + EXTRACT(YEAR FROM '2009-07-02') + CAST(1 AS UNSIGNED)",
    99  			ast.FlagHasFunc,
   100  		},
   101  		{
   102  			"substring('abc', 1)",
   103  			ast.FlagHasFunc,
   104  		},
   105  		{
   106  			"sum(a)",
   107  			ast.FlagHasAggregateFunc | ast.FlagHasReference,
   108  		},
   109  		{
   110  			"(select 1) as a",
   111  			ast.FlagHasSubquery,
   112  		},
   113  		{
   114  			"@auto_commit",
   115  			ast.FlagHasVariable,
   116  		},
   117  		{
   118  			"default(a)",
   119  			ast.FlagHasDefault,
   120  		},
   121  		{
   122  			"a is null",
   123  			ast.FlagHasReference,
   124  		},
   125  		{
   126  			"1 is true",
   127  			ast.FlagConstant,
   128  		},
   129  		{
   130  			"a in (1, count(*), 3)",
   131  			ast.FlagConstant | ast.FlagHasReference | ast.FlagHasAggregateFunc,
   132  		},
   133  		{
   134  			"'Michael!' REGEXP '.*'",
   135  			ast.FlagConstant,
   136  		},
   137  		{
   138  			"a REGEXP '.*'",
   139  			ast.FlagHasReference,
   140  		},
   141  		{
   142  			"-a",
   143  			ast.FlagHasReference,
   144  		},
   145  	}
   146  	for _, tt := range flagTests {
   147  		stmt, err := ts.ParseOneStmt("select "+tt.expr, "", "")
   148  		c.Assert(err, IsNil)
   149  		selectStmt := stmt.(*ast.SelectStmt)
   150  		ast.SetFlag(selectStmt)
   151  		expr := selectStmt.Fields.Fields[0].Expr
   152  		c.Assert(expr.GetFlag(), Equals, tt.flag, Commentf("For %s", tt.expr))
   153  	}
   154  }