github.com/XiaoMi/Gaea@v1.2.5/parser/ast/flag.go (about)

     1  // Copyright 2015 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package ast
    15  
    16  // HasAggFlag checks if the expr contains FlagHasAggregateFunc.
    17  func HasAggFlag(expr ExprNode) bool {
    18  	return expr.GetFlag()&FlagHasAggregateFunc > 0
    19  }
    20  
    21  // HasWindowFlag check if the expr contains FlagHasWindowFunc
    22  func HasWindowFlag(expr ExprNode) bool {
    23  	return expr.GetFlag()&FlagHasWindowFunc > 0
    24  }
    25  
    26  // SetFlag sets flag for expression.
    27  func SetFlag(n Node) {
    28  	var setter flagSetter
    29  	n.Accept(&setter)
    30  }
    31  
    32  type flagSetter struct {
    33  }
    34  
    35  func (f *flagSetter) Enter(in Node) (Node, bool) {
    36  	return in, false
    37  }
    38  
    39  func (f *flagSetter) Leave(in Node) (Node, bool) {
    40  	if x, ok := in.(ParamMarkerExpr); ok {
    41  		x.SetFlag(FlagHasParamMarker)
    42  	}
    43  	switch x := in.(type) {
    44  	case *AggregateFuncExpr:
    45  		f.aggregateFunc(x)
    46  	case *WindowFuncExpr:
    47  		f.windowFunc(x)
    48  	case *BetweenExpr:
    49  		x.SetFlag(x.Expr.GetFlag() | x.Left.GetFlag() | x.Right.GetFlag())
    50  	case *BinaryOperationExpr:
    51  		x.SetFlag(x.L.GetFlag() | x.R.GetFlag())
    52  	case *CaseExpr:
    53  		f.caseExpr(x)
    54  	case *ColumnNameExpr:
    55  		x.SetFlag(FlagHasReference)
    56  	case *CompareSubqueryExpr:
    57  		x.SetFlag(x.L.GetFlag() | x.R.GetFlag())
    58  	case *DefaultExpr:
    59  		x.SetFlag(FlagHasDefault)
    60  	case *ExistsSubqueryExpr:
    61  		x.SetFlag(x.Sel.GetFlag())
    62  	case *FuncCallExpr:
    63  		f.funcCall(x)
    64  	case *FuncCastExpr:
    65  		x.SetFlag(FlagHasFunc | x.Expr.GetFlag())
    66  	case *IsNullExpr:
    67  		x.SetFlag(x.Expr.GetFlag())
    68  	case *IsTruthExpr:
    69  		x.SetFlag(x.Expr.GetFlag())
    70  	case *ParenthesesExpr:
    71  		x.SetFlag(x.Expr.GetFlag())
    72  	case *PatternInExpr:
    73  		f.patternIn(x)
    74  	case *PatternLikeExpr:
    75  		f.patternLike(x)
    76  	case *PatternRegexpExpr:
    77  		f.patternRegexp(x)
    78  	case *PositionExpr:
    79  		x.SetFlag(FlagHasReference)
    80  	case *RowExpr:
    81  		f.row(x)
    82  	case *SubqueryExpr:
    83  		x.SetFlag(FlagHasSubquery)
    84  	case *UnaryOperationExpr:
    85  		x.SetFlag(x.V.GetFlag())
    86  	case *ValuesExpr:
    87  		x.SetFlag(FlagHasReference)
    88  	case *VariableExpr:
    89  		if x.Value == nil {
    90  			x.SetFlag(FlagHasVariable)
    91  		} else {
    92  			x.SetFlag(FlagHasVariable | x.Value.GetFlag())
    93  		}
    94  	}
    95  
    96  	return in, true
    97  }
    98  
    99  func (f *flagSetter) caseExpr(x *CaseExpr) {
   100  	var flag uint64
   101  	if x.Value != nil {
   102  		flag |= x.Value.GetFlag()
   103  	}
   104  	for _, val := range x.WhenClauses {
   105  		flag |= val.Expr.GetFlag()
   106  		flag |= val.Result.GetFlag()
   107  	}
   108  	if x.ElseClause != nil {
   109  		flag |= x.ElseClause.GetFlag()
   110  	}
   111  	x.SetFlag(flag)
   112  }
   113  
   114  func (f *flagSetter) patternIn(x *PatternInExpr) {
   115  	flag := x.Expr.GetFlag()
   116  	for _, val := range x.List {
   117  		flag |= val.GetFlag()
   118  	}
   119  	if x.Sel != nil {
   120  		flag |= x.Sel.GetFlag()
   121  	}
   122  	x.SetFlag(flag)
   123  }
   124  
   125  func (f *flagSetter) patternLike(x *PatternLikeExpr) {
   126  	flag := x.Pattern.GetFlag()
   127  	if x.Expr != nil {
   128  		flag |= x.Expr.GetFlag()
   129  	}
   130  	x.SetFlag(flag)
   131  }
   132  
   133  func (f *flagSetter) patternRegexp(x *PatternRegexpExpr) {
   134  	flag := x.Pattern.GetFlag()
   135  	if x.Expr != nil {
   136  		flag |= x.Expr.GetFlag()
   137  	}
   138  	x.SetFlag(flag)
   139  }
   140  
   141  func (f *flagSetter) row(x *RowExpr) {
   142  	var flag uint64
   143  	for _, val := range x.Values {
   144  		flag |= val.GetFlag()
   145  	}
   146  	x.SetFlag(flag)
   147  }
   148  
   149  func (f *flagSetter) funcCall(x *FuncCallExpr) {
   150  	flag := FlagHasFunc
   151  	for _, val := range x.Args {
   152  		flag |= val.GetFlag()
   153  	}
   154  	x.SetFlag(flag)
   155  }
   156  
   157  func (f *flagSetter) aggregateFunc(x *AggregateFuncExpr) {
   158  	flag := FlagHasAggregateFunc
   159  	for _, val := range x.Args {
   160  		flag |= val.GetFlag()
   161  	}
   162  	x.SetFlag(flag)
   163  }
   164  
   165  func (f *flagSetter) windowFunc(x *WindowFuncExpr) {
   166  	flag := FlagHasWindowFunc
   167  	for _, val := range x.Args {
   168  		flag |= val.GetFlag()
   169  	}
   170  	x.SetFlag(flag)
   171  }