github.com/mithrandie/csvq@v1.18.1/lib/query/field_analyzer.go (about)

     1  package query
     2  
     3  import (
     4  	"strings"
     5  
     6  	"github.com/mithrandie/csvq/lib/parser"
     7  )
     8  
     9  func HasAggregateFunction(expr parser.QueryExpression, scope *ReferenceScope) (bool, error) {
    10  	switch expr.(type) {
    11  	case parser.AggregateFunction, parser.ListFunction:
    12  		return true, nil
    13  	case parser.Function:
    14  		e := expr.(parser.Function)
    15  		if strings.ToUpper(e.Name) == "JSON_OBJECT" {
    16  			return false, nil
    17  		}
    18  
    19  		if udfn, err := scope.GetFunction(expr, expr.(parser.Function).Name); err == nil && udfn.IsAggregate {
    20  			return true, nil
    21  		}
    22  
    23  		return HasAggregateFunctionInList(e.Args, scope)
    24  	case parser.PrimitiveType, parser.FieldReference, parser.ColumnNumber, parser.Subquery, parser.Exists,
    25  		parser.Variable, parser.EnvironmentVariable, parser.RuntimeInformation, parser.Constant, parser.Flag,
    26  		parser.CursorStatus, parser.CursorAttrebute, parser.Placeholder,
    27  		parser.AllColumns:
    28  		return false, nil
    29  	case parser.Parentheses:
    30  		return HasAggregateFunction(expr.(parser.Parentheses).Expr, scope)
    31  	case parser.Arithmetic:
    32  		e := expr.(parser.Arithmetic)
    33  		return HasAggregateFunctionInList([]parser.QueryExpression{e.LHS, e.RHS}, scope)
    34  	case parser.UnaryArithmetic:
    35  		return HasAggregateFunction(expr.(parser.UnaryArithmetic).Operand, scope)
    36  	case parser.Concat:
    37  		return HasAggregateFunctionInList(expr.(parser.Concat).Items, scope)
    38  	case parser.Comparison:
    39  		e := expr.(parser.Comparison)
    40  		return HasAggregateFunctionInList([]parser.QueryExpression{e.LHS, e.RHS}, scope)
    41  	case parser.Is:
    42  		e := expr.(parser.Is)
    43  		return HasAggregateFunctionInList([]parser.QueryExpression{e.LHS, e.RHS}, scope)
    44  	case parser.Between:
    45  		e := expr.(parser.Between)
    46  		return HasAggregateFunctionInList([]parser.QueryExpression{e.LHS, e.Low, e.High}, scope)
    47  	case parser.Like:
    48  		e := expr.(parser.Like)
    49  		return HasAggregateFunctionInList([]parser.QueryExpression{e.LHS, e.Pattern}, scope)
    50  	case parser.In:
    51  		e := expr.(parser.In)
    52  		return hasAggFuncInRowValueComparison(e.LHS, e.Values, scope)
    53  	case parser.Any:
    54  		e := expr.(parser.Any)
    55  		return hasAggFuncInRowValueComparison(e.LHS, e.Values, scope)
    56  	case parser.All:
    57  		e := expr.(parser.All)
    58  		return hasAggFuncInRowValueComparison(e.LHS, e.Values, scope)
    59  	case parser.AnalyticFunction:
    60  		e := expr.(parser.AnalyticFunction)
    61  		values := make([]parser.QueryExpression, 0, len(e.Args)+2)
    62  		values = append(values, e.Args...)
    63  
    64  		if e.AnalyticClause.PartitionClause != nil {
    65  			values = append(values, e.AnalyticClause.PartitionClause.(parser.PartitionClause).Values...)
    66  		}
    67  		if e.AnalyticClause.OrderByClause != nil {
    68  			values = append(values, GetValuesInOrderByClause(e.AnalyticClause.OrderByClause.(parser.OrderByClause))...)
    69  		}
    70  
    71  		return HasAggregateFunctionInList(values, scope)
    72  	case parser.CaseExpr:
    73  		e := expr.(parser.CaseExpr)
    74  		values := make([]parser.QueryExpression, 0, len(e.When)+2)
    75  		if e.Value != nil {
    76  			values = append(values, e.Value)
    77  		}
    78  
    79  		for _, v := range e.When {
    80  			w := v.(parser.CaseExprWhen)
    81  			values = append(values, w.Condition, w.Result)
    82  		}
    83  
    84  		if e.Else != nil {
    85  			values = append(values, e.Else.(parser.CaseExprElse).Result)
    86  		}
    87  
    88  		return HasAggregateFunctionInList(values, scope)
    89  	case parser.Logic:
    90  		e := expr.(parser.Logic)
    91  		return HasAggregateFunctionInList([]parser.QueryExpression{e.LHS, e.RHS}, scope)
    92  	case parser.UnaryLogic:
    93  		return HasAggregateFunction(expr.(parser.UnaryLogic).Operand, scope)
    94  	case parser.VariableSubstitution:
    95  		return HasAggregateFunction(expr.(parser.VariableSubstitution).Value, scope)
    96  	default:
    97  		return false, NewInvalidValueExpressionError(expr)
    98  	}
    99  }
   100  
   101  func HasAggregateFunctionInList(list []parser.QueryExpression, scope *ReferenceScope) (bool, error) {
   102  	for _, op := range list {
   103  		ok, err := HasAggregateFunction(op, scope)
   104  		if err != nil {
   105  			return false, err
   106  		}
   107  		if ok {
   108  			return true, nil
   109  		}
   110  	}
   111  	return false, nil
   112  }
   113  
   114  func GetValuesInOrderByClause(e parser.OrderByClause) []parser.QueryExpression {
   115  	values := make([]parser.QueryExpression, 0, len(e.Items))
   116  	for _, v := range e.Items {
   117  		values = append(values, v.(parser.OrderItem).Value)
   118  	}
   119  	return values
   120  }
   121  
   122  func hasAggFuncInRowValueComparison(lhs parser.QueryExpression, values parser.QueryExpression, scope *ReferenceScope) (bool, error) {
   123  	val, err := hasAggFuncInRowValue(lhs, scope)
   124  	if err != nil {
   125  		return false, err
   126  	}
   127  	if val {
   128  		return true, nil
   129  	}
   130  
   131  	return hasAggFuncInRowValue(values, scope)
   132  }
   133  
   134  func hasAggFuncInRowValue(expr parser.QueryExpression, scope *ReferenceScope) (bool, error) {
   135  	switch expr.(type) {
   136  	case parser.Subquery, parser.JsonQuery:
   137  		return false, nil
   138  	case parser.ValueList:
   139  		return HasAggregateFunctionInList(expr.(parser.ValueList).Values, scope)
   140  	case parser.RowValue:
   141  		return hasAggFuncInRowValue(expr.(parser.RowValue).Value, scope)
   142  	case parser.RowValueList:
   143  		e := expr.(parser.RowValueList)
   144  		for _, v := range e.RowValues {
   145  			ok, err := hasAggFuncInRowValue(v, scope)
   146  			if err != nil {
   147  				return false, err
   148  			}
   149  			if ok {
   150  				return true, nil
   151  			}
   152  		}
   153  		return false, nil
   154  	default:
   155  		return HasAggregateFunction(expr, scope)
   156  	}
   157  }
   158  
   159  func SearchAnalyticFunctions(expr parser.QueryExpression) ([]parser.AnalyticFunction, error) {
   160  	switch expr.(type) {
   161  	case parser.AnalyticFunction:
   162  		e := expr.(parser.AnalyticFunction)
   163  		values := make([]parser.QueryExpression, 0, len(e.Args)+2)
   164  		values = append(values, e.Args...)
   165  
   166  		if e.AnalyticClause.PartitionClause != nil {
   167  			values = append(values, e.AnalyticClause.PartitionClause.(parser.PartitionClause).Values...)
   168  		}
   169  		if e.AnalyticClause.OrderByClause != nil {
   170  			values = append(values, GetValuesInOrderByClause(e.AnalyticClause.OrderByClause.(parser.OrderByClause))...)
   171  		}
   172  
   173  		childFuncs, err := SearchAnalyticFunctionsInList(values)
   174  		if err != nil {
   175  			return nil, err
   176  		}
   177  
   178  		return appendAnalyticFunctionToListIfNotExist(childFuncs, []parser.AnalyticFunction{e}), nil
   179  	case parser.PrimitiveType, parser.FieldReference, parser.ColumnNumber, parser.Subquery, parser.Exists,
   180  		parser.Variable, parser.EnvironmentVariable, parser.RuntimeInformation, parser.Constant, parser.Flag,
   181  		parser.CursorStatus, parser.CursorAttrebute, parser.Placeholder,
   182  		parser.AllColumns:
   183  		return nil, nil
   184  	case parser.Parentheses:
   185  		return SearchAnalyticFunctions(expr.(parser.Parentheses).Expr)
   186  	case parser.Arithmetic:
   187  		e := expr.(parser.Arithmetic)
   188  		return SearchAnalyticFunctionsInList([]parser.QueryExpression{e.LHS, e.RHS})
   189  	case parser.UnaryArithmetic:
   190  		return SearchAnalyticFunctions(expr.(parser.UnaryArithmetic).Operand)
   191  	case parser.Concat:
   192  		return SearchAnalyticFunctionsInList(expr.(parser.Concat).Items)
   193  	case parser.Comparison:
   194  		e := expr.(parser.Comparison)
   195  		return SearchAnalyticFunctionsInList([]parser.QueryExpression{e.LHS, e.RHS})
   196  	case parser.Is:
   197  		e := expr.(parser.Is)
   198  		return SearchAnalyticFunctionsInList([]parser.QueryExpression{e.LHS, e.RHS})
   199  	case parser.Between:
   200  		e := expr.(parser.Between)
   201  		return SearchAnalyticFunctionsInList([]parser.QueryExpression{e.LHS, e.Low, e.High})
   202  	case parser.Like:
   203  		e := expr.(parser.Like)
   204  		return SearchAnalyticFunctionsInList([]parser.QueryExpression{e.LHS, e.Pattern})
   205  	case parser.In:
   206  		e := expr.(parser.In)
   207  		return searchAnalyticFunctionsInRowValueComparison(e.LHS, e.Values)
   208  	case parser.Any:
   209  		e := expr.(parser.Any)
   210  		return searchAnalyticFunctionsInRowValueComparison(e.LHS, e.Values)
   211  	case parser.All:
   212  		e := expr.(parser.All)
   213  		return searchAnalyticFunctionsInRowValueComparison(e.LHS, e.Values)
   214  	case parser.Function:
   215  		if strings.ToUpper(expr.(parser.Function).Name) == "JSON_OBJECT" {
   216  			return nil, nil
   217  		}
   218  		return SearchAnalyticFunctionsInList(expr.(parser.Function).Args)
   219  	case parser.AggregateFunction:
   220  		return SearchAnalyticFunctionsInList(expr.(parser.AggregateFunction).Args)
   221  	case parser.ListFunction:
   222  		return SearchAnalyticFunctionsInList(expr.(parser.ListFunction).Args)
   223  	case parser.CaseExpr:
   224  		e := expr.(parser.CaseExpr)
   225  		values := make([]parser.QueryExpression, 0, len(e.When)+2)
   226  		if e.Value != nil {
   227  			values = append(values, e.Value)
   228  		}
   229  
   230  		for _, v := range e.When {
   231  			w := v.(parser.CaseExprWhen)
   232  			values = append(values, w.Condition, w.Result)
   233  		}
   234  
   235  		if e.Else != nil {
   236  			values = append(values, e.Else.(parser.CaseExprElse).Result)
   237  		}
   238  
   239  		return SearchAnalyticFunctionsInList(values)
   240  	case parser.Logic:
   241  		e := expr.(parser.Logic)
   242  		return SearchAnalyticFunctionsInList([]parser.QueryExpression{e.LHS, e.RHS})
   243  	case parser.UnaryLogic:
   244  		return SearchAnalyticFunctions(expr.(parser.UnaryLogic).Operand)
   245  	case parser.VariableSubstitution:
   246  		return SearchAnalyticFunctions(expr.(parser.VariableSubstitution).Value)
   247  	default:
   248  		return nil, NewInvalidValueExpressionError(expr)
   249  	}
   250  }
   251  
   252  func SearchAnalyticFunctionsInList(list []parser.QueryExpression) ([]parser.AnalyticFunction, error) {
   253  	var funcs []parser.AnalyticFunction = nil
   254  	for _, op := range list {
   255  		children, err := SearchAnalyticFunctions(op)
   256  		if err != nil {
   257  			return funcs, err
   258  		}
   259  		if children != nil {
   260  			funcs = appendAnalyticFunctionToListIfNotExist(children, funcs)
   261  		}
   262  	}
   263  	return funcs, nil
   264  }
   265  
   266  func appendAnalyticFunctionToListIfNotExist(list1 []parser.AnalyticFunction, list2 []parser.AnalyticFunction) []parser.AnalyticFunction {
   267  	var createMap = func(list []parser.AnalyticFunction) map[string]parser.AnalyticFunction {
   268  		m := make(map[string]parser.AnalyticFunction, len(list))
   269  		for _, v := range list {
   270  			m[FormatFieldIdentifier(v)] = v
   271  		}
   272  		return m
   273  	}
   274  
   275  	m1 := createMap(list1)
   276  	m2 := createMap(list2)
   277  	for k, v := range m2 {
   278  		if _, ok := m1[k]; !ok {
   279  			list1 = append(list1, v)
   280  		}
   281  	}
   282  
   283  	return list1
   284  }
   285  
   286  func searchAnalyticFunctionsInRowValueComparison(lhs parser.QueryExpression, values parser.QueryExpression) ([]parser.AnalyticFunction, error) {
   287  	var funcs []parser.AnalyticFunction = nil
   288  
   289  	children, err := searchAnalyticFunctionsInRowValue(lhs)
   290  	if err != nil {
   291  		return funcs, err
   292  	}
   293  	if children != nil {
   294  		funcs = appendAnalyticFunctionToListIfNotExist(children, funcs)
   295  	}
   296  
   297  	childrenInValues, err := searchAnalyticFunctionsInRowValue(values)
   298  	if err != nil {
   299  		return funcs, err
   300  	}
   301  	if childrenInValues != nil {
   302  		funcs = appendAnalyticFunctionToListIfNotExist(childrenInValues, funcs)
   303  	}
   304  
   305  	return funcs, nil
   306  }
   307  
   308  func searchAnalyticFunctionsInRowValue(expr parser.QueryExpression) ([]parser.AnalyticFunction, error) {
   309  	switch expr.(type) {
   310  	case parser.Subquery, parser.JsonQuery:
   311  		return nil, nil
   312  	case parser.ValueList:
   313  		e := expr.(parser.ValueList)
   314  		return SearchAnalyticFunctionsInList(e.Values)
   315  	case parser.RowValue:
   316  		return searchAnalyticFunctionsInRowValue(expr.(parser.RowValue).Value)
   317  	case parser.RowValueList:
   318  		var funcs []parser.AnalyticFunction = nil
   319  
   320  		e := expr.(parser.RowValueList)
   321  		for _, v := range e.RowValues {
   322  			children, err := searchAnalyticFunctionsInRowValue(v)
   323  			if err != nil {
   324  				return funcs, err
   325  			}
   326  			if children != nil {
   327  				funcs = appendAnalyticFunctionToListIfNotExist(children, funcs)
   328  			}
   329  		}
   330  
   331  		return funcs, nil
   332  	default:
   333  		return SearchAnalyticFunctions(expr)
   334  	}
   335  }