github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqlparse/tidbparser/ast/flag.go (about)

     1  // Copyright 2015 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package ast
    15  
    16  // HasAggFlag checks if the expr contains FlagHasAggregateFunc.
    17  func HasAggFlag(expr ExprNode) bool {
    18  	return expr.GetFlag()&FlagHasAggregateFunc > 0
    19  }
    20  
    21  // SetFlag sets flag for expression.
    22  func SetFlag(n Node) {
    23  	var setter flagSetter
    24  	n.Accept(&setter)
    25  }
    26  
    27  type flagSetter struct{}
    28  
    29  func (f *flagSetter) Enter(in Node) (Node, bool) {
    30  	return in, false
    31  }
    32  
    33  func (f *flagSetter) Leave(in Node) (Node, bool) {
    34  	switch x := in.(type) {
    35  	case *AggregateFuncExpr:
    36  		f.aggregateFunc(x)
    37  	case *BetweenExpr:
    38  		x.SetFlag(x.Expr.GetFlag() | x.Left.GetFlag() | x.Right.GetFlag())
    39  	case *BinaryOperationExpr:
    40  		x.SetFlag(x.L.GetFlag() | x.R.GetFlag())
    41  	case *CaseExpr:
    42  		f.caseExpr(x)
    43  	case *ColumnNameExpr:
    44  		x.SetFlag(FlagHasReference)
    45  	case *CompareSubqueryExpr:
    46  		x.SetFlag(x.L.GetFlag() | x.R.GetFlag())
    47  	case *DefaultExpr:
    48  		x.SetFlag(FlagHasDefault)
    49  	case *ExistsSubqueryExpr:
    50  		x.SetFlag(x.Sel.GetFlag())
    51  	case *FuncCallExpr:
    52  		f.funcCall(x)
    53  	case *FuncCastExpr:
    54  		x.SetFlag(FlagHasFunc | x.Expr.GetFlag())
    55  	case *IsNullExpr:
    56  		x.SetFlag(x.Expr.GetFlag())
    57  	case *IsTruthExpr:
    58  		x.SetFlag(x.Expr.GetFlag())
    59  	case *ParamMarkerExpr:
    60  		x.SetFlag(FlagHasParamMarker)
    61  	case *ParenthesesExpr:
    62  		x.SetFlag(x.Expr.GetFlag())
    63  	case *PatternInExpr:
    64  		f.patternIn(x)
    65  	case *PatternLikeExpr:
    66  		f.patternLike(x)
    67  	case *PatternRegexpExpr:
    68  		f.patternRegexp(x)
    69  	case *PositionExpr:
    70  		x.SetFlag(FlagHasReference)
    71  	case *RowExpr:
    72  		f.row(x)
    73  	case *SubqueryExpr:
    74  		x.SetFlag(FlagHasSubquery)
    75  	case *UnaryOperationExpr:
    76  		x.SetFlag(x.V.GetFlag())
    77  	case *ValueExpr:
    78  	case *ValuesExpr:
    79  		x.SetFlag(FlagHasReference)
    80  	case *VariableExpr:
    81  		if x.Value == nil {
    82  			x.SetFlag(FlagHasVariable)
    83  		} else {
    84  			x.SetFlag(FlagHasVariable | x.Value.GetFlag())
    85  		}
    86  	}
    87  
    88  	return in, true
    89  }
    90  
    91  func (f *flagSetter) caseExpr(x *CaseExpr) {
    92  	var flag uint64
    93  	if x.Value != nil {
    94  		flag |= x.Value.GetFlag()
    95  	}
    96  	for _, val := range x.WhenClauses {
    97  		flag |= val.Expr.GetFlag()
    98  		flag |= val.Result.GetFlag()
    99  	}
   100  	if x.ElseClause != nil {
   101  		flag |= x.ElseClause.GetFlag()
   102  	}
   103  	x.SetFlag(flag)
   104  }
   105  
   106  func (f *flagSetter) patternIn(x *PatternInExpr) {
   107  	flag := x.Expr.GetFlag()
   108  	for _, val := range x.List {
   109  		flag |= val.GetFlag()
   110  	}
   111  	if x.Sel != nil {
   112  		flag |= x.Sel.GetFlag()
   113  	}
   114  	x.SetFlag(flag)
   115  }
   116  
   117  func (f *flagSetter) patternLike(x *PatternLikeExpr) {
   118  	flag := x.Pattern.GetFlag()
   119  	if x.Expr != nil {
   120  		flag |= x.Expr.GetFlag()
   121  	}
   122  	x.SetFlag(flag)
   123  }
   124  
   125  func (f *flagSetter) patternRegexp(x *PatternRegexpExpr) {
   126  	flag := x.Pattern.GetFlag()
   127  	if x.Expr != nil {
   128  		flag |= x.Expr.GetFlag()
   129  	}
   130  	x.SetFlag(flag)
   131  }
   132  
   133  func (f *flagSetter) row(x *RowExpr) {
   134  	var flag uint64
   135  	for _, val := range x.Values {
   136  		flag |= val.GetFlag()
   137  	}
   138  	x.SetFlag(flag)
   139  }
   140  
   141  func (f *flagSetter) funcCall(x *FuncCallExpr) {
   142  	flag := FlagHasFunc
   143  	for _, val := range x.Args {
   144  		flag |= val.GetFlag()
   145  	}
   146  	x.SetFlag(flag)
   147  }
   148  
   149  func (f *flagSetter) aggregateFunc(x *AggregateFuncExpr) {
   150  	flag := FlagHasAggregateFunc
   151  	for _, val := range x.Args {
   152  		flag |= val.GetFlag()
   153  	}
   154  	x.SetFlag(flag)
   155  }