github.com/minio/minio@v0.0.0-20240328213742-3f72439b8a27/internal/s3select/sql/aggregation.go (about)

     1  // Copyright (c) 2015-2021 MinIO, Inc.
     2  //
     3  // This file is part of MinIO Object Storage stack
     4  //
     5  // This program is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Affero General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // This program is distributed in the hope that it will be useful
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13  // GNU Affero General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Affero General Public License
    16  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17  
    18  package sql
    19  
    20  import (
    21  	"errors"
    22  	"fmt"
    23  )
    24  
// Aggregation Function name constants. These are the FuncName values
// that FuncExpr.aggregateRow and FuncExpr.getAggregate recognize as
// aggregation functions.
const (
	aggFnAvg   FuncName = "AVG"   // finalized as running sum divided by running count
	aggFnCount FuncName = "COUNT" // counts rows for COUNT(*), otherwise non-null argument values
	aggFnMax   FuncName = "MAX"   // tracked in aggVal.runningMax
	aggFnMin   FuncName = "MIN"   // tracked in aggVal.runningMin
	aggFnSum   FuncName = "SUM"   // tracked in aggVal.runningSum
)
    33  
    34  var (
    35  	errNonNumericArg = func(fnStr FuncName) error {
    36  		return fmt.Errorf("%s() requires a numeric argument", fnStr)
    37  	}
    38  	errInvalidAggregation = errors.New("Invalid aggregation seen")
    39  )
    40  
    41  type aggVal struct {
    42  	runningSum             *Value
    43  	runningCount           int64
    44  	runningMax, runningMin *Value
    45  
    46  	// Stores if at least one record has been seen
    47  	seen bool
    48  }
    49  
    50  func newAggVal(fn FuncName) *aggVal {
    51  	switch fn {
    52  	case aggFnAvg, aggFnSum:
    53  		return &aggVal{runningSum: FromFloat(0)}
    54  	case aggFnMin:
    55  		return &aggVal{runningMin: FromInt(0)}
    56  	case aggFnMax:
    57  		return &aggVal{runningMax: FromInt(0)}
    58  	default:
    59  		return &aggVal{}
    60  	}
    61  }
    62  
    63  // evalAggregationNode - performs partial computation using the
    64  // current row and stores the result.
    65  //
    66  // On success, it returns (nil, nil).
    67  func (e *FuncExpr) evalAggregationNode(r Record, tableAlias string) error {
    68  	// It is assumed that this function is called only when
    69  	// `e` is an aggregation function.
    70  
    71  	var val *Value
    72  	var err error
    73  	funcName := e.getFunctionName()
    74  	if aggFnCount == funcName {
    75  		if e.Count.StarArg {
    76  			// Handle COUNT(*)
    77  			e.aggregate.runningCount++
    78  			return nil
    79  		}
    80  
    81  		val, err = e.Count.ExprArg.evalNode(r, tableAlias)
    82  		if err != nil {
    83  			return err
    84  		}
    85  	} else {
    86  		// Evaluate the (only) argument
    87  		val, err = e.SFunc.ArgsList[0].evalNode(r, tableAlias)
    88  		if err != nil {
    89  			return err
    90  		}
    91  	}
    92  
    93  	if val.IsNull() {
    94  		// E.g. the column or field does not exist in the
    95  		// record - in all such cases the aggregation is not
    96  		// updated.
    97  		return nil
    98  	}
    99  
   100  	argVal := val
   101  	if funcName != aggFnCount {
   102  		// All aggregation functions, except COUNT require a
   103  		// numeric argument.
   104  
   105  		// Here, we diverge from Amazon S3 behavior by
   106  		// inferring untyped values are numbers.
   107  		if !argVal.isNumeric() {
   108  			if i, ok := argVal.bytesToInt(); ok {
   109  				argVal.setInt(i)
   110  			} else if f, ok := argVal.bytesToFloat(); ok {
   111  				argVal.setFloat(f)
   112  			} else {
   113  				return errNonNumericArg(funcName)
   114  			}
   115  		}
   116  	}
   117  
   118  	// Mark that we have seen one non-null value.
   119  	isFirstRow := false
   120  	if !e.aggregate.seen {
   121  		e.aggregate.seen = true
   122  		isFirstRow = true
   123  	}
   124  
   125  	switch funcName {
   126  	case aggFnCount:
   127  		// For all non-null values, the count is incremented.
   128  		e.aggregate.runningCount++
   129  
   130  	case aggFnAvg, aggFnSum:
   131  		e.aggregate.runningCount++
   132  		// Convert to float.
   133  		f, ok := argVal.ToFloat()
   134  		if !ok {
   135  			return fmt.Errorf("Could not convert value %v (%s) to a number", argVal.value, argVal.GetTypeString())
   136  		}
   137  		argVal.setFloat(f)
   138  		err = e.aggregate.runningSum.arithOp(opPlus, argVal)
   139  
   140  	case aggFnMin:
   141  		err = e.aggregate.runningMin.minmax(argVal, false, isFirstRow)
   142  
   143  	case aggFnMax:
   144  		err = e.aggregate.runningMax.minmax(argVal, true, isFirstRow)
   145  
   146  	default:
   147  		err = errInvalidAggregation
   148  	}
   149  
   150  	return err
   151  }
   152  
// aggregateRow feeds the record r to the aggregation functions inside
// the aliased expression by delegating to the wrapped Expression. The
// alias itself plays no part in aggregation.
func (e *AliasedExpression) aggregateRow(r Record, tableAlias string) error {
	return e.Expression.aggregateRow(r, tableAlias)
}
   156  
   157  func (e *Expression) aggregateRow(r Record, tableAlias string) error {
   158  	for _, ex := range e.And {
   159  		err := ex.aggregateRow(r, tableAlias)
   160  		if err != nil {
   161  			return err
   162  		}
   163  	}
   164  	return nil
   165  }
   166  
   167  func (e *ListExpr) aggregateRow(r Record, tableAlias string) error {
   168  	for _, ex := range e.Elements {
   169  		err := ex.aggregateRow(r, tableAlias)
   170  		if err != nil {
   171  			return err
   172  		}
   173  	}
   174  	return nil
   175  }
   176  
   177  func (e *AndCondition) aggregateRow(r Record, tableAlias string) error {
   178  	for _, ex := range e.Condition {
   179  		err := ex.aggregateRow(r, tableAlias)
   180  		if err != nil {
   181  			return err
   182  		}
   183  	}
   184  	return nil
   185  }
   186  
   187  func (e *Condition) aggregateRow(r Record, tableAlias string) error {
   188  	if e.Operand != nil {
   189  		return e.Operand.aggregateRow(r, tableAlias)
   190  	}
   191  	return e.Not.aggregateRow(r, tableAlias)
   192  }
   193  
   194  func (e *ConditionOperand) aggregateRow(r Record, tableAlias string) error {
   195  	err := e.Operand.aggregateRow(r, tableAlias)
   196  	if err != nil {
   197  		return err
   198  	}
   199  
   200  	if e.ConditionRHS == nil {
   201  		return nil
   202  	}
   203  
   204  	switch {
   205  	case e.ConditionRHS.Compare != nil:
   206  		return e.ConditionRHS.Compare.Operand.aggregateRow(r, tableAlias)
   207  	case e.ConditionRHS.Between != nil:
   208  		err = e.ConditionRHS.Between.Start.aggregateRow(r, tableAlias)
   209  		if err != nil {
   210  			return err
   211  		}
   212  		return e.ConditionRHS.Between.End.aggregateRow(r, tableAlias)
   213  	case e.ConditionRHS.In != nil:
   214  		if e.ConditionRHS.In.ListExpr != nil {
   215  			return e.ConditionRHS.In.ListExpr.aggregateRow(r, tableAlias)
   216  		}
   217  		return nil
   218  	case e.ConditionRHS.Like != nil:
   219  		err = e.ConditionRHS.Like.Pattern.aggregateRow(r, tableAlias)
   220  		if err != nil {
   221  			return err
   222  		}
   223  		return e.ConditionRHS.Like.EscapeChar.aggregateRow(r, tableAlias)
   224  	default:
   225  		return errInvalidASTNode
   226  	}
   227  }
   228  
   229  func (e *Operand) aggregateRow(r Record, tableAlias string) error {
   230  	err := e.Left.aggregateRow(r, tableAlias)
   231  	if err != nil {
   232  		return err
   233  	}
   234  	for _, rt := range e.Right {
   235  		err = rt.Right.aggregateRow(r, tableAlias)
   236  		if err != nil {
   237  			return err
   238  		}
   239  	}
   240  	return nil
   241  }
   242  
   243  func (e *MultOp) aggregateRow(r Record, tableAlias string) error {
   244  	err := e.Left.aggregateRow(r, tableAlias)
   245  	if err != nil {
   246  		return err
   247  	}
   248  	for _, rt := range e.Right {
   249  		err = rt.Right.aggregateRow(r, tableAlias)
   250  		if err != nil {
   251  			return err
   252  		}
   253  	}
   254  	return nil
   255  }
   256  
   257  func (e *UnaryTerm) aggregateRow(r Record, tableAlias string) error {
   258  	if e.Negated != nil {
   259  		return e.Negated.Term.aggregateRow(r, tableAlias)
   260  	}
   261  	return e.Primary.aggregateRow(r, tableAlias)
   262  }
   263  
   264  func (e *PrimaryTerm) aggregateRow(r Record, tableAlias string) error {
   265  	switch {
   266  	case e.ListExpr != nil:
   267  		return e.ListExpr.aggregateRow(r, tableAlias)
   268  	case e.SubExpression != nil:
   269  		return e.SubExpression.aggregateRow(r, tableAlias)
   270  	case e.FuncCall != nil:
   271  		return e.FuncCall.aggregateRow(r, tableAlias)
   272  	}
   273  	return nil
   274  }
   275  
   276  func (e *FuncExpr) aggregateRow(r Record, tableAlias string) error {
   277  	switch e.getFunctionName() {
   278  	case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount:
   279  		return e.evalAggregationNode(r, tableAlias)
   280  	default:
   281  		// TODO: traverse arguments and call aggregateRow on
   282  		// them if they could be an ancestor of an
   283  		// aggregation.
   284  	}
   285  	return nil
   286  }
   287  
   288  // getAggregate() implementation for each AST node follows. This is
   289  // called after calling aggregateRow() on each input row, to calculate
   290  // the final aggregate result.
   291  
   292  func (e *FuncExpr) getAggregate() (*Value, error) {
   293  	switch e.getFunctionName() {
   294  	case aggFnCount:
   295  		return FromInt(e.aggregate.runningCount), nil
   296  
   297  	case aggFnAvg:
   298  		if e.aggregate.runningCount == 0 {
   299  			// No rows were seen by AVG.
   300  			return FromNull(), nil
   301  		}
   302  		err := e.aggregate.runningSum.arithOp(opDivide, FromInt(e.aggregate.runningCount))
   303  		return e.aggregate.runningSum, err
   304  
   305  	case aggFnMin:
   306  		if !e.aggregate.seen {
   307  			// No rows were seen by MIN
   308  			return FromNull(), nil
   309  		}
   310  		return e.aggregate.runningMin, nil
   311  
   312  	case aggFnMax:
   313  		if !e.aggregate.seen {
   314  			// No rows were seen by MAX
   315  			return FromNull(), nil
   316  		}
   317  		return e.aggregate.runningMax, nil
   318  
   319  	case aggFnSum:
   320  		// TODO: check if returning 0 when no rows were seen
   321  		// by SUM is expected behavior.
   322  		return e.aggregate.runningSum, nil
   323  
   324  	default:
   325  		// TODO:
   326  	}
   327  
   328  	return nil, errInvalidAggregation
   329  }