github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/aggregations.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"github.com/dolthub/go-mysql-server/sql"
    19  	"github.com/dolthub/go-mysql-server/sql/expression"
    20  	"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation"
    21  	"github.com/dolthub/go-mysql-server/sql/plan"
    22  	"github.com/dolthub/go-mysql-server/sql/transform"
    23  	"github.com/dolthub/go-mysql-server/sql/types"
    24  )
    25  
    26  // flattenAggregationExpressions flattens any complex aggregate or window expressions in a GroupBy or Window node and
    27  // adds a projection on top of the result. The child terms of any complex expressions get pushed down to become selected
    28  // expressions in the GroupBy or Window, and then a new project node re-applies the original expression to the new
    29  // schema of the flattened node.
    30  // e.g. GroupBy(sum(a) + sum(b)) becomes project(sum(a) + sum(b), GroupBy(sum(a), sum(b)).
    31  // e.g. Window(sum(a) + sum(b) over (partition by a)) becomes
    32  // project(sum(a) + sum(b) over (partition by a), Window(sum(a), sum(b) over (partition by a))).
    33  func flattenAggregationExpressions(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    34  	span, ctx := ctx.Span("flatten_aggregation_exprs")
    35  	defer span.End()
    36  
    37  	if !n.Resolved() {
    38  		return n, transform.SameTree, nil
    39  	}
    40  
    41  	return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
    42  		switch n := n.(type) {
    43  		case *plan.Window:
    44  			if !hasHiddenAggregations(n.SelectExprs) && !hasHiddenWindows(n.SelectExprs) {
    45  				return n, transform.SameTree, nil
    46  			}
    47  
    48  			return flattenedWindow(ctx, scope, n.SelectExprs, n.Child)
    49  		case *plan.GroupBy:
    50  			if !hasHiddenAggregations(n.SelectedExprs) {
    51  				return n, transform.SameTree, nil
    52  			}
    53  
    54  			return flattenedGroupBy(ctx, scope, n.SelectedExprs, n.GroupByExprs, n.Child)
    55  		default:
    56  			return n, transform.SameTree, nil
    57  		}
    58  	})
    59  }
    60  
    61  func flattenedGroupBy(ctx *sql.Context, scope *plan.Scope, projection, grouping []sql.Expression, child sql.Node) (sql.Node, transform.TreeIdentity, error) {
    62  	newProjection, newAggregates, allSame, err := replaceAggregatesWithGetFieldProjections(ctx, scope, projection)
    63  	if err != nil {
    64  		return nil, transform.SameTree, err
    65  	}
    66  	if allSame {
    67  		return nil, transform.SameTree, nil
    68  	}
    69  	return plan.NewProject(
    70  		newProjection,
    71  		plan.NewGroupBy(newAggregates, grouping, child),
    72  	), transform.NewTree, nil
    73  }
    74  
// replaceAggregatesWithGetFieldProjections inserts an indirection Projection
// between an aggregation and its scope output, resulting in two buckets of
// expressions:
// 1) Parent projection expressions.
// 2) Child aggregation expressions.
//
// A scope always returns a fixed number of columns, so the number of projection
// inputs and outputs must match.
//
// The aggregation must provide input dependencies for parent projections.
// Each parent expression can depend on zero or many aggregation expressions.
// There are two basic kinds of aggregation expressions:
// 1) Passthrough columns from scope input relation.
// 2) Synthesized columns from in-scope aggregation relation.
//
// Return values:
//   - projections: exactly one expression per entry of projection, in the same order.
//   - aggregations: the expressions the child aggregation node must select. GetField
//     indexes placed in projections are computed as len(scope.Schema()) plus the
//     position of the referenced entry in this slice at the time it was appended.
//   - identity: always transform.NewTree on success, transform.SameTree on error.
//   - err: the error returned by transform.Expr, if any; the other results are nil then.
//
// The context argument is unused.
func replaceAggregatesWithGetFieldProjections(_ *sql.Context, scope *plan.Scope, projection []sql.Expression) (projections, aggregations []sql.Expression, identity transform.TreeIdentity, err error) {
	var newProjection = make([]sql.Expression, len(projection))
	var newAggregates []sql.Expression
	// Column offset applied to every GetField index synthesized below.
	scopeLen := len(scope.Schema())
	// Set of e.String() values of every aggregation replaced so far; used further down to
	// recognise GetFields that were synthesized here (they are named with that same string).
	aggPassthrough := make(map[string]struct{})
	/* every aggregation creates one pass-through reference into parent */
	for i, p := range projection {
		// Replace each aggregation / window aggregation inside p with a GetField that refers
		// to the slot the aggregation occupies in newAggregates.
		e, same, err := transform.Expr(p, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
			switch e := e.(type) {
			case sql.Aggregation, sql.WindowAggregation:
				newAggregates = append(newAggregates, e)
				aggPassthrough[e.String()] = struct{}{}
				// The reference normally carries the aggregation's own type, except that
				// Sum/Avg references are typed Float64 and Count references Int64.
				// NOTE(review): the reason for these overrides is not visible here — confirm.
				typ := e.Type()
				switch e.(type) {
				case *aggregation.Sum, *aggregation.Avg:
					typ = types.Float64
				case *aggregation.Count:
					typ = types.Int64
				}
				return expression.NewGetField(scopeLen+len(newAggregates)-1, typ, e.String(), e.IsNullable()), transform.NewTree, nil
			default:
				return e, transform.SameTree, nil
			}
		})
		if err != nil {
			return nil, nil, transform.SameTree, err
		}

		if same {
			// No aggregation was replaced inside this expression.
			var getField *expression.GetField
			// add to plan.GroupBy.SelectedExprs iff expression has an expression.GetField
			// getField records a GetField for which the callback returned true.
			// NOTE(review): presumably InspectExpr stops at the first true, making this the
			// first GetField encountered — confirm against transform.InspectExpr.
			hasGetField := transform.InspectExpr(e, func(expr sql.Expression) bool {
				gf, ok := expr.(*expression.GetField)
				if ok {
					getField = gf
				}
				return ok
			})
			if hasGetField {
				// The whole expression is selected by the aggregation node and the parent
				// projection reads it back through a GetField at its slot. Table id and
				// database are copied from the GetField found above; name and source come
				// from getNameAndSource(e).
				newAggregates = append(newAggregates, e)
				name, source := getNameAndSource(e)
				newProjection[i] = expression.NewGetFieldWithTable(
					scopeLen+len(newAggregates)-1, int(getField.TableId()), e.Type(), getField.Database(), source, name, e.IsNullable(),
				)
			} else {
				// No column dependency at all: the expression stays in the projection as is.
				newProjection[i] = e
			}
		} else {
			// At least one aggregation was replaced: the rewritten expression is evaluated
			// by the parent projection.
			newProjection[i] = e
			// The callback always returns false, so it is used purely for its side effect
			// of collecting GetFields.
			transform.InspectExpr(e, func(e sql.Expression) bool {
				// clean up projection dependency columns not synthesized by
				// aggregation.
				switch e := e.(type) {
				case *expression.GetField:
					if _, ok := aggPassthrough[e.Name()]; !ok {
						// this is a column input to the projection that
						// the aggregation parent has not passed-through.
						// TODO: for functions without aggregate dependency,
						// we just execute the function in the aggregation.
						// why don't we do that for both?
						newAggregates = append(newAggregates, e)
					}
				default:
				}
				return false
			})
		}
	}

	return newProjection, newAggregates, transform.NewTree, nil
}
   160  
   161  func flattenedWindow(ctx *sql.Context, scope *plan.Scope, projection []sql.Expression, child sql.Node) (sql.Node, transform.TreeIdentity, error) {
   162  	newProjection, newAggregates, allSame, err := replaceAggregatesWithGetFieldProjections(ctx, scope, projection)
   163  	if err != nil {
   164  		return nil, transform.SameTree, err
   165  	}
   166  	if allSame {
   167  		return nil, allSame, nil
   168  	}
   169  	return plan.NewProject(
   170  		newProjection,
   171  		plan.NewWindow(newAggregates, child),
   172  	), transform.NewTree, nil
   173  }
   174  
   175  func getNameAndSource(e sql.Expression) (name, source string) {
   176  	if n, ok := e.(sql.Nameable); ok {
   177  		name = n.Name()
   178  	} else {
   179  		name = e.String()
   180  	}
   181  
   182  	if t, ok := e.(sql.Tableable); ok {
   183  		source = t.Table()
   184  	}
   185  
   186  	return
   187  }
   188  
   189  // hasHiddenAggregations returns whether any of the given expressions has a hidden aggregation. That is, an aggregation
   190  // that is not at the root of the expression.
   191  func hasHiddenAggregations(exprs []sql.Expression) bool {
   192  	for _, e := range exprs {
   193  		if containsHiddenAggregation(e) {
   194  			return true
   195  		}
   196  	}
   197  	return false
   198  }
   199  
   200  // containsHiddenAggregation returns whether the given expressions has a hidden aggregation. That is, an aggregation
   201  // that is not at the root of the expression.
   202  func containsHiddenAggregation(e sql.Expression) bool {
   203  	_, ok := e.(sql.Aggregation)
   204  	if ok {
   205  		return false
   206  	}
   207  
   208  	return containsAggregation(e)
   209  }
   210  
   211  // containsAggregation returns whether the expression given contains any sql.Aggregation terms.
   212  func containsAggregation(e sql.Expression) bool {
   213  	var hasAgg bool
   214  	sql.Inspect(e, func(e sql.Expression) bool {
   215  		if _, ok := e.(sql.Aggregation); ok {
   216  			hasAgg = true
   217  			return false
   218  		}
   219  		return true
   220  	})
   221  	return hasAgg
   222  }
   223  
   224  // hasHiddenWindows returns whether any of the given expression have a hidden window function. That is, a window
   225  // function that is not at the root of the expression.
   226  func hasHiddenWindows(exprs []sql.Expression) bool {
   227  	for _, e := range exprs {
   228  		if containsHiddenWindow(e) {
   229  			return true
   230  		}
   231  	}
   232  	return false
   233  }
   234  
   235  // containsHiddenWindow returns whether the given expression has a hidden window function. That is, a window function
   236  // that is not at the root of the expression.
   237  func containsHiddenWindow(e sql.Expression) bool {
   238  	_, ok := e.(sql.WindowAggregation)
   239  	if ok {
   240  		return false
   241  	}
   242  
   243  	return containsWindow(e)
   244  }
   245  
   246  // containsWindow returns whether the expression given contains any sql.WindowAggregation terms.
   247  func containsWindow(e sql.Expression) bool {
   248  	var hasAgg bool
   249  	sql.Inspect(e, func(e sql.Expression) bool {
   250  		if _, ok := e.(sql.WindowAggregation); ok {
   251  			hasAgg = true
   252  			return false
   253  		}
   254  		return true
   255  	})
   256  	return hasAgg
   257  }