storj.io/minio@v0.0.0-20230509071714-0cbc90f649b1/pkg/s3select/sql/aggregation.go (about)

     1  /*
     2   * MinIO Cloud Storage, (C) 2019 MinIO, Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package sql
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  )
    23  
// Aggregation Function name constants
//
// These are the function names that the aggregation code in this file
// recognizes; FuncExpr.aggregateRow dispatches on them.
const (
	aggFnAvg   FuncName = "AVG"
	aggFnCount FuncName = "COUNT"
	aggFnMax   FuncName = "MAX"
	aggFnMin   FuncName = "MIN"
	aggFnSum   FuncName = "SUM"
)
    32  
var (
	// errNonNumericArg builds the error reported when the aggregation
	// function named fnStr receives an argument that is not numeric
	// and cannot be inferred to be a number.
	errNonNumericArg = func(fnStr FuncName) error {
		return fmt.Errorf("%s() requires a numeric argument", fnStr)
	}
	// errInvalidAggregation is returned when the aggregation code is
	// reached with a function name that is not one of the aggregation
	// function constants.
	errInvalidAggregation = errors.New("Invalid aggregation seen")
)
    39  
// aggVal holds the running (partial) state of a single aggregation
// function while input rows are being processed.
type aggVal struct {
	// runningSum accumulates the argument values for AVG and SUM.
	runningSum             *Value
	// runningCount is incremented for COUNT, AVG and SUM; AVG divides
	// runningSum by it to produce the final result.
	runningCount           int64
	// runningMax and runningMin track the extreme values for MAX and
	// MIN respectively.
	runningMax, runningMin *Value

	// Stores if at least one non-null argument value has been seen.
	// Note that the COUNT(*) path returns before this flag is set.
	seen bool
}
    48  
    49  func newAggVal(fn FuncName) *aggVal {
    50  	switch fn {
    51  	case aggFnAvg, aggFnSum:
    52  		return &aggVal{runningSum: FromFloat(0)}
    53  	case aggFnMin:
    54  		return &aggVal{runningMin: FromInt(0)}
    55  	case aggFnMax:
    56  		return &aggVal{runningMax: FromInt(0)}
    57  	default:
    58  		return &aggVal{}
    59  	}
    60  }
    61  
    62  // evalAggregationNode - performs partial computation using the
    63  // current row and stores the result.
    64  //
    65  // On success, it returns (nil, nil).
    66  func (e *FuncExpr) evalAggregationNode(r Record, tableAlias string) error {
    67  	// It is assumed that this function is called only when
    68  	// `e` is an aggregation function.
    69  
    70  	var val *Value
    71  	var err error
    72  	funcName := e.getFunctionName()
    73  	if aggFnCount == funcName {
    74  		if e.Count.StarArg {
    75  			// Handle COUNT(*)
    76  			e.aggregate.runningCount++
    77  			return nil
    78  		}
    79  
    80  		val, err = e.Count.ExprArg.evalNode(r, tableAlias)
    81  		if err != nil {
    82  			return err
    83  		}
    84  	} else {
    85  		// Evaluate the (only) argument
    86  		val, err = e.SFunc.ArgsList[0].evalNode(r, tableAlias)
    87  		if err != nil {
    88  			return err
    89  		}
    90  	}
    91  
    92  	if val.IsNull() {
    93  		// E.g. the column or field does not exist in the
    94  		// record - in all such cases the aggregation is not
    95  		// updated.
    96  		return nil
    97  	}
    98  
    99  	argVal := val
   100  	if funcName != aggFnCount {
   101  		// All aggregation functions, except COUNT require a
   102  		// numeric argument.
   103  
   104  		// Here, we diverge from Amazon S3 behavior by
   105  		// inferring untyped values are numbers.
   106  		if !argVal.isNumeric() {
   107  			if i, ok := argVal.bytesToInt(); ok {
   108  				argVal.setInt(i)
   109  			} else if f, ok := argVal.bytesToFloat(); ok {
   110  				argVal.setFloat(f)
   111  			} else {
   112  				return errNonNumericArg(funcName)
   113  			}
   114  		}
   115  	}
   116  
   117  	// Mark that we have seen one non-null value.
   118  	isFirstRow := false
   119  	if !e.aggregate.seen {
   120  		e.aggregate.seen = true
   121  		isFirstRow = true
   122  	}
   123  
   124  	switch funcName {
   125  	case aggFnCount:
   126  		// For all non-null values, the count is incremented.
   127  		e.aggregate.runningCount++
   128  
   129  	case aggFnAvg, aggFnSum:
   130  		e.aggregate.runningCount++
   131  		// Convert to float.
   132  		f, ok := argVal.ToFloat()
   133  		if !ok {
   134  			return fmt.Errorf("Could not convert value %v (%s) to a number", argVal.value, argVal.GetTypeString())
   135  		}
   136  		argVal.setFloat(f)
   137  		err = e.aggregate.runningSum.arithOp(opPlus, argVal)
   138  
   139  	case aggFnMin:
   140  		err = e.aggregate.runningMin.minmax(argVal, false, isFirstRow)
   141  
   142  	case aggFnMax:
   143  		err = e.aggregate.runningMax.minmax(argVal, true, isFirstRow)
   144  
   145  	default:
   146  		err = errInvalidAggregation
   147  	}
   148  
   149  	return err
   150  }
   151  
// aggregateRow feeds the record r to the aggregation functions inside
// the aliased expression by delegating to the wrapped Expression. The
// alias itself plays no part here.
func (e *AliasedExpression) aggregateRow(r Record, tableAlias string) error {
	return e.Expression.aggregateRow(r, tableAlias)
}
   155  
   156  func (e *Expression) aggregateRow(r Record, tableAlias string) error {
   157  	for _, ex := range e.And {
   158  		err := ex.aggregateRow(r, tableAlias)
   159  		if err != nil {
   160  			return err
   161  		}
   162  	}
   163  	return nil
   164  }
   165  
   166  func (e *ListExpr) aggregateRow(r Record, tableAlias string) error {
   167  	for _, ex := range e.Elements {
   168  		err := ex.aggregateRow(r, tableAlias)
   169  		if err != nil {
   170  			return err
   171  		}
   172  	}
   173  	return nil
   174  }
   175  
   176  func (e *AndCondition) aggregateRow(r Record, tableAlias string) error {
   177  	for _, ex := range e.Condition {
   178  		err := ex.aggregateRow(r, tableAlias)
   179  		if err != nil {
   180  			return err
   181  		}
   182  	}
   183  	return nil
   184  }
   185  
   186  func (e *Condition) aggregateRow(r Record, tableAlias string) error {
   187  	if e.Operand != nil {
   188  		return e.Operand.aggregateRow(r, tableAlias)
   189  	}
   190  	return e.Not.aggregateRow(r, tableAlias)
   191  }
   192  
   193  func (e *ConditionOperand) aggregateRow(r Record, tableAlias string) error {
   194  	err := e.Operand.aggregateRow(r, tableAlias)
   195  	if err != nil {
   196  		return err
   197  	}
   198  
   199  	if e.ConditionRHS == nil {
   200  		return nil
   201  	}
   202  
   203  	switch {
   204  	case e.ConditionRHS.Compare != nil:
   205  		return e.ConditionRHS.Compare.Operand.aggregateRow(r, tableAlias)
   206  	case e.ConditionRHS.Between != nil:
   207  		err = e.ConditionRHS.Between.Start.aggregateRow(r, tableAlias)
   208  		if err != nil {
   209  			return err
   210  		}
   211  		return e.ConditionRHS.Between.End.aggregateRow(r, tableAlias)
   212  	case e.ConditionRHS.In != nil:
   213  		elt := e.ConditionRHS.In.ListExpression
   214  		err = elt.aggregateRow(r, tableAlias)
   215  		if err != nil {
   216  			return err
   217  		}
   218  		return nil
   219  	case e.ConditionRHS.Like != nil:
   220  		err = e.ConditionRHS.Like.Pattern.aggregateRow(r, tableAlias)
   221  		if err != nil {
   222  			return err
   223  		}
   224  		return e.ConditionRHS.Like.EscapeChar.aggregateRow(r, tableAlias)
   225  	default:
   226  		return errInvalidASTNode
   227  	}
   228  }
   229  
   230  func (e *Operand) aggregateRow(r Record, tableAlias string) error {
   231  	err := e.Left.aggregateRow(r, tableAlias)
   232  	if err != nil {
   233  		return err
   234  	}
   235  	for _, rt := range e.Right {
   236  		err = rt.Right.aggregateRow(r, tableAlias)
   237  		if err != nil {
   238  			return err
   239  		}
   240  	}
   241  	return nil
   242  }
   243  
   244  func (e *MultOp) aggregateRow(r Record, tableAlias string) error {
   245  	err := e.Left.aggregateRow(r, tableAlias)
   246  	if err != nil {
   247  		return err
   248  	}
   249  	for _, rt := range e.Right {
   250  		err = rt.Right.aggregateRow(r, tableAlias)
   251  		if err != nil {
   252  			return err
   253  		}
   254  	}
   255  	return nil
   256  }
   257  
   258  func (e *UnaryTerm) aggregateRow(r Record, tableAlias string) error {
   259  	if e.Negated != nil {
   260  		return e.Negated.Term.aggregateRow(r, tableAlias)
   261  	}
   262  	return e.Primary.aggregateRow(r, tableAlias)
   263  }
   264  
   265  func (e *PrimaryTerm) aggregateRow(r Record, tableAlias string) error {
   266  	switch {
   267  	case e.ListExpr != nil:
   268  		return e.ListExpr.aggregateRow(r, tableAlias)
   269  	case e.SubExpression != nil:
   270  		return e.SubExpression.aggregateRow(r, tableAlias)
   271  	case e.FuncCall != nil:
   272  		return e.FuncCall.aggregateRow(r, tableAlias)
   273  	}
   274  	return nil
   275  }
   276  
   277  func (e *FuncExpr) aggregateRow(r Record, tableAlias string) error {
   278  	switch e.getFunctionName() {
   279  	case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount:
   280  		return e.evalAggregationNode(r, tableAlias)
   281  	default:
   282  		// TODO: traverse arguments and call aggregateRow on
   283  		// them if they could be an ancestor of an
   284  		// aggregation.
   285  	}
   286  	return nil
   287  }
   288  
   289  // getAggregate() implementation for each AST node follows. This is
   290  // called after calling aggregateRow() on each input row, to calculate
   291  // the final aggregate result.
   292  
   293  func (e *FuncExpr) getAggregate() (*Value, error) {
   294  	switch e.getFunctionName() {
   295  	case aggFnCount:
   296  		return FromInt(e.aggregate.runningCount), nil
   297  
   298  	case aggFnAvg:
   299  		if e.aggregate.runningCount == 0 {
   300  			// No rows were seen by AVG.
   301  			return FromNull(), nil
   302  		}
   303  		err := e.aggregate.runningSum.arithOp(opDivide, FromInt(e.aggregate.runningCount))
   304  		return e.aggregate.runningSum, err
   305  
   306  	case aggFnMin:
   307  		if !e.aggregate.seen {
   308  			// No rows were seen by MIN
   309  			return FromNull(), nil
   310  		}
   311  		return e.aggregate.runningMin, nil
   312  
   313  	case aggFnMax:
   314  		if !e.aggregate.seen {
   315  			// No rows were seen by MAX
   316  			return FromNull(), nil
   317  		}
   318  		return e.aggregate.runningMax, nil
   319  
   320  	case aggFnSum:
   321  		// TODO: check if returning 0 when no rows were seen
   322  		// by SUM is expected behavior.
   323  		return e.aggregate.runningSum, nil
   324  
   325  	default:
   326  		// TODO:
   327  	}
   328  
   329  	return nil, errInvalidAggregation
   330  }