github.com/XiaoMi/Gaea@v1.2.5/parser/ast/flag_test.go (about) 1 // Copyright 2016 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package ast_test 15 16 import ( 17 "testing" 18 19 . "github.com/pingcap/check" 20 21 "github.com/XiaoMi/Gaea/parser" 22 "github.com/XiaoMi/Gaea/parser/ast" 23 ) 24 25 func TestT(t *testing.T) { 26 CustomVerboseFlag = true 27 TestingT(t) 28 } 29 30 var _ = Suite(&testFlagSuite{}) 31 32 type testFlagSuite struct { 33 *parser.Parser 34 } 35 36 func (ts *testFlagSuite) SetUpSuite(c *C) { 37 ts.Parser = parser.New() 38 } 39 40 func (ts *testFlagSuite) TestHasAggFlag(c *C) { 41 expr := &ast.BetweenExpr{} 42 flagTests := []struct { 43 flag uint64 44 hasAgg bool 45 }{ 46 {ast.FlagHasAggregateFunc, true}, 47 {ast.FlagHasAggregateFunc | ast.FlagHasVariable, true}, 48 {ast.FlagHasVariable, false}, 49 } 50 for _, tt := range flagTests { 51 expr.SetFlag(tt.flag) 52 c.Assert(ast.HasAggFlag(expr), Equals, tt.hasAgg) 53 } 54 } 55 56 func (ts *testFlagSuite) TestFlag(c *C) { 57 flagTests := []struct { 58 expr string 59 flag uint64 60 }{ 61 { 62 "1 between 0 and 2", 63 ast.FlagConstant, 64 }, 65 { 66 "case 1 when 1 then 1 else 0 end", 67 ast.FlagConstant, 68 }, 69 { 70 "case 1 when 1 then 1 else 0 end", 71 ast.FlagConstant, 72 }, 73 { 74 "case 1 when a > 1 then 1 else 0 end", 75 ast.FlagConstant | ast.FlagHasReference, 76 }, 77 { 78 "1 = ANY (select 1) OR exists (select 1)", 79 ast.FlagHasSubquery, 80 }, 81 { 82 "1 in (1) or 1 is true or null is null or 'abc' like 'abc' or 'abc' rlike 'abc'", 83 ast.FlagConstant, 84 }, 85 { 86 "row (1, 1) = row (1, 1)", 87 ast.FlagConstant, 88 }, 89 { 90 "(1 + a) > ?", 91 ast.FlagHasReference | ast.FlagHasParamMarker, 92 }, 93 { 94 "trim('abc ')", 95 ast.FlagHasFunc, 96 }, 97 { 98 "now() + EXTRACT(YEAR FROM '2009-07-02') + CAST(1 AS UNSIGNED)", 99 ast.FlagHasFunc, 100 }, 101 { 102 "substring('abc', 1)", 103 ast.FlagHasFunc, 104 }, 105 { 106 "sum(a)", 107 ast.FlagHasAggregateFunc | ast.FlagHasReference, 108 }, 109 { 110 "(select 1) as a", 111 ast.FlagHasSubquery, 112 }, 113 { 114 "@auto_commit", 115 ast.FlagHasVariable, 116 }, 117 { 118 "default(a)", 119 ast.FlagHasDefault, 120 }, 121 { 122 "a is null", 123 ast.FlagHasReference, 124 }, 125 { 126 "1 is true", 127 ast.FlagConstant, 128 }, 129 { 130 "a in (1, count(*), 3)", 131 ast.FlagConstant | ast.FlagHasReference | ast.FlagHasAggregateFunc, 132 }, 133 { 134 "'Michael!' REGEXP '.*'", 135 ast.FlagConstant, 136 }, 137 { 138 "a REGEXP '.*'", 139 ast.FlagHasReference, 140 }, 141 { 142 "-a", 143 ast.FlagHasReference, 144 }, 145 } 146 for _, tt := range flagTests { 147 stmt, err := ts.ParseOneStmt("select "+tt.expr, "", "") 148 c.Assert(err, IsNil) 149 selectStmt := stmt.(*ast.SelectStmt) 150 ast.SetFlag(selectStmt) 151 expr := selectStmt.Fields.Fields[0].Expr 152 c.Assert(expr.GetFlag(), Equals, tt.flag, Commentf("For %s", tt.expr)) 153 } 154 }