github.com/minio/minio@v0.0.0-20240328213742-3f72439b8a27/internal/s3select/sql/analysis.go (about)

     1  // Copyright (c) 2015-2021 MinIO, Inc.
     2  //
     3  // This file is part of MinIO Object Storage stack
     4  //
     5  // This program is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Affero General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // This program is distributed in the hope that it will be useful
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13  // GNU Affero General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Affero General Public License
    16  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17  
    18  package sql
    19  
    20  import (
    21  	"errors"
    22  	"fmt"
    23  	"strings"
    24  )
    25  
    26  // Query analysis - The query is analyzed to determine if it involves
    27  // aggregation.
    28  //
    29  // Aggregation functions - An expression that involves aggregation of
    30  // rows in some manner. Requires all input rows to be processed,
    31  // before a result is returned.
    32  //
    33  // Row function - An expression that depends on a value in the
    34  // row. They have an output for each input row.
    35  //
// Some types of queries are not valid. For example, an aggregation
    37  // function combined with a row function is meaningless ("AVG(s.Age) +
    38  // s.Salary"). Analysis determines if such a scenario exists so an
    39  // error can be returned.
    40  
    41  var (
    42  	// Fatal error for query processing.
    43  	errNestedAggregation      = errors.New("Cannot nest aggregations")
    44  	errFunctionNotImplemented = errors.New("Function is not yet implemented")
    45  	errUnexpectedInvalidNode  = errors.New("Unexpected node value")
    46  	errInvalidKeypath         = errors.New("A provided keypath is invalid")
    47  )
    48  
    49  // qProp contains analysis info about an SQL term.
    50  type qProp struct {
    51  	isAggregation, isRowFunc bool
    52  
    53  	err error
    54  }
    55  
    56  // `combine` combines a pair of `qProp`s, so that errors are
    57  // propagated correctly, and checks that an aggregation is not being
    58  // combined with a row-function term.
    59  func (p *qProp) combine(q qProp) {
    60  	switch {
    61  	case p.err != nil:
    62  		// Do nothing
    63  	case q.err != nil:
    64  		p.err = q.err
    65  	default:
    66  		p.isAggregation = p.isAggregation || q.isAggregation
    67  		p.isRowFunc = p.isRowFunc || q.isRowFunc
    68  		if p.isAggregation && p.isRowFunc {
    69  			p.err = errNestedAggregation
    70  		}
    71  	}
    72  }
    73  
    74  func (e *SelectExpression) analyze(s *Select) (result qProp) {
    75  	if e.All {
    76  		return qProp{isRowFunc: true}
    77  	}
    78  
    79  	for _, ex := range e.Expressions {
    80  		result.combine(ex.analyze(s))
    81  	}
    82  	return
    83  }
    84  
    85  func (e *AliasedExpression) analyze(s *Select) qProp {
    86  	return e.Expression.analyze(s)
    87  }
    88  
    89  func (e *Expression) analyze(s *Select) (result qProp) {
    90  	for _, ac := range e.And {
    91  		result.combine(ac.analyze(s))
    92  	}
    93  	return
    94  }
    95  
    96  func (e *AndCondition) analyze(s *Select) (result qProp) {
    97  	for _, ac := range e.Condition {
    98  		result.combine(ac.analyze(s))
    99  	}
   100  	return
   101  }
   102  
   103  func (e *Condition) analyze(s *Select) (result qProp) {
   104  	if e.Operand != nil {
   105  		result = e.Operand.analyze(s)
   106  	} else {
   107  		result = e.Not.analyze(s)
   108  	}
   109  	return
   110  }
   111  
   112  func (e *ListExpr) analyze(s *Select) (result qProp) {
   113  	for _, ac := range e.Elements {
   114  		result.combine(ac.analyze(s))
   115  	}
   116  	return
   117  }
   118  
   119  func (e *ConditionOperand) analyze(s *Select) (result qProp) {
   120  	if e.ConditionRHS == nil {
   121  		result = e.Operand.analyze(s)
   122  	} else {
   123  		result.combine(e.Operand.analyze(s))
   124  		result.combine(e.ConditionRHS.analyze(s))
   125  	}
   126  	return
   127  }
   128  
   129  func (e *ConditionRHS) analyze(s *Select) (result qProp) {
   130  	switch {
   131  	case e.Compare != nil:
   132  		result = e.Compare.Operand.analyze(s)
   133  	case e.Between != nil:
   134  		result.combine(e.Between.Start.analyze(s))
   135  		result.combine(e.Between.End.analyze(s))
   136  	case e.In != nil:
   137  		result.combine(e.In.analyze(s))
   138  	case e.Like != nil:
   139  		result.combine(e.Like.Pattern.analyze(s))
   140  		if e.Like.EscapeChar != nil {
   141  			result.combine(e.Like.EscapeChar.analyze(s))
   142  		}
   143  	default:
   144  		result = qProp{err: errUnexpectedInvalidNode}
   145  	}
   146  	return
   147  }
   148  
   149  func (e *In) analyze(s *Select) (result qProp) {
   150  	switch {
   151  	case e.JPathExpr != nil:
   152  		// Check if the path expression is valid
   153  		if len(e.JPathExpr.PathExpr) > 0 {
   154  			if e.JPathExpr.BaseKey.String() != s.From.As && !strings.EqualFold(e.JPathExpr.BaseKey.String(), baseTableName) {
   155  				result = qProp{err: errInvalidKeypath}
   156  				return
   157  			}
   158  		}
   159  		result = qProp{isRowFunc: true}
   160  	case e.ListExpr != nil:
   161  		result = e.ListExpr.analyze(s)
   162  	default:
   163  		result = qProp{err: errUnexpectedInvalidNode}
   164  	}
   165  	return
   166  }
   167  
   168  func (e *Operand) analyze(s *Select) (result qProp) {
   169  	result.combine(e.Left.analyze(s))
   170  	for _, r := range e.Right {
   171  		result.combine(r.Right.analyze(s))
   172  	}
   173  	return
   174  }
   175  
   176  func (e *MultOp) analyze(s *Select) (result qProp) {
   177  	result.combine(e.Left.analyze(s))
   178  	for _, r := range e.Right {
   179  		result.combine(r.Right.analyze(s))
   180  	}
   181  	return
   182  }
   183  
   184  func (e *UnaryTerm) analyze(s *Select) (result qProp) {
   185  	if e.Negated != nil {
   186  		result = e.Negated.Term.analyze(s)
   187  	} else {
   188  		result = e.Primary.analyze(s)
   189  	}
   190  	return
   191  }
   192  
   193  func (e *PrimaryTerm) analyze(s *Select) (result qProp) {
   194  	switch {
   195  	case e.Value != nil:
   196  		result = qProp{}
   197  
   198  	case e.JPathExpr != nil:
   199  		// Check if the path expression is valid
   200  		if len(e.JPathExpr.PathExpr) > 0 {
   201  			if e.JPathExpr.BaseKey.String() != s.From.As && !strings.EqualFold(e.JPathExpr.BaseKey.String(), baseTableName) {
   202  				result = qProp{err: errInvalidKeypath}
   203  				return
   204  			}
   205  		}
   206  		result = qProp{isRowFunc: true}
   207  
   208  	case e.ListExpr != nil:
   209  		result = e.ListExpr.analyze(s)
   210  
   211  	case e.SubExpression != nil:
   212  		result = e.SubExpression.analyze(s)
   213  
   214  	case e.FuncCall != nil:
   215  		result = e.FuncCall.analyze(s)
   216  
   217  	default:
   218  		result = qProp{err: errUnexpectedInvalidNode}
   219  	}
   220  	return
   221  }
   222  
   223  func (e *FuncExpr) analyze(s *Select) (result qProp) {
   224  	funcName := e.getFunctionName()
   225  
   226  	switch funcName {
   227  	case sqlFnCast:
   228  		return e.Cast.Expr.analyze(s)
   229  
   230  	case sqlFnExtract:
   231  		return e.Extract.From.analyze(s)
   232  
   233  	case sqlFnDateAdd:
   234  		result.combine(e.DateAdd.Quantity.analyze(s))
   235  		result.combine(e.DateAdd.Timestamp.analyze(s))
   236  		return result
   237  
   238  	case sqlFnDateDiff:
   239  		result.combine(e.DateDiff.Timestamp1.analyze(s))
   240  		result.combine(e.DateDiff.Timestamp2.analyze(s))
   241  		return result
   242  
   243  	// Handle aggregation function calls
   244  	case aggFnAvg, aggFnMax, aggFnMin, aggFnSum, aggFnCount:
   245  		// Initialize accumulator
   246  		e.aggregate = newAggVal(funcName)
   247  
   248  		var exprA qProp
   249  		if funcName == aggFnCount {
   250  			if e.Count.StarArg {
   251  				return qProp{isAggregation: true}
   252  			}
   253  
   254  			exprA = e.Count.ExprArg.analyze(s)
   255  		} else {
   256  			if len(e.SFunc.ArgsList) != 1 {
   257  				return qProp{err: fmt.Errorf("%s takes exactly one argument", funcName)}
   258  			}
   259  			exprA = e.SFunc.ArgsList[0].analyze(s)
   260  		}
   261  
   262  		if exprA.err != nil {
   263  			return exprA
   264  		}
   265  		if exprA.isAggregation {
   266  			return qProp{err: errNestedAggregation}
   267  		}
   268  		return qProp{isAggregation: true}
   269  
   270  	case sqlFnCoalesce:
   271  		if len(e.SFunc.ArgsList) == 0 {
   272  			return qProp{err: fmt.Errorf("%s needs at least one argument", string(funcName))}
   273  		}
   274  		for _, arg := range e.SFunc.ArgsList {
   275  			result.combine(arg.analyze(s))
   276  		}
   277  		return result
   278  
   279  	case sqlFnNullIf:
   280  		if len(e.SFunc.ArgsList) != 2 {
   281  			return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
   282  		}
   283  		for _, arg := range e.SFunc.ArgsList {
   284  			result.combine(arg.analyze(s))
   285  		}
   286  		return result
   287  
   288  	case sqlFnCharLength, sqlFnCharacterLength:
   289  		if len(e.SFunc.ArgsList) != 1 {
   290  			return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
   291  		}
   292  		for _, arg := range e.SFunc.ArgsList {
   293  			result.combine(arg.analyze(s))
   294  		}
   295  		return result
   296  
   297  	case sqlFnLower, sqlFnUpper:
   298  		if len(e.SFunc.ArgsList) != 1 {
   299  			return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
   300  		}
   301  		for _, arg := range e.SFunc.ArgsList {
   302  			result.combine(arg.analyze(s))
   303  		}
   304  		return result
   305  
   306  	case sqlFnTrim:
   307  		if e.Trim.TrimChars != nil {
   308  			result.combine(e.Trim.TrimChars.analyze(s))
   309  		}
   310  		if e.Trim.TrimFrom != nil {
   311  			result.combine(e.Trim.TrimFrom.analyze(s))
   312  		}
   313  		return result
   314  
   315  	case sqlFnSubstring:
   316  		errVal := fmt.Errorf("Invalid argument(s) to %s", string(funcName))
   317  		result.combine(e.Substring.Expr.analyze(s))
   318  		switch {
   319  		case e.Substring.From != nil:
   320  			result.combine(e.Substring.From.analyze(s))
   321  			if e.Substring.For != nil {
   322  				result.combine(e.Substring.Expr.analyze(s))
   323  			}
   324  		case e.Substring.Arg2 != nil:
   325  			result.combine(e.Substring.Arg2.analyze(s))
   326  			if e.Substring.Arg3 != nil {
   327  				result.combine(e.Substring.Arg3.analyze(s))
   328  			}
   329  		default:
   330  			result.err = errVal
   331  		}
   332  		return result
   333  
   334  	case sqlFnUTCNow:
   335  		if len(e.SFunc.ArgsList) != 0 {
   336  			result.err = fmt.Errorf("%s() takes no arguments", string(funcName))
   337  		}
   338  		return result
   339  	}
   340  
   341  	// TODO: implement other functions
   342  	return qProp{err: errFunctionNotImplemented}
   343  }