github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/aggregations.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "github.com/dolthub/go-mysql-server/sql" 19 "github.com/dolthub/go-mysql-server/sql/expression" 20 "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation" 21 "github.com/dolthub/go-mysql-server/sql/plan" 22 "github.com/dolthub/go-mysql-server/sql/transform" 23 "github.com/dolthub/go-mysql-server/sql/types" 24 ) 25 26 // flattenAggregationExpressions flattens any complex aggregate or window expressions in a GroupBy or Window node and 27 // adds a projection on top of the result. The child terms of any complex expressions get pushed down to become selected 28 // expressions in the GroupBy or Window, and then a new project node re-applies the original expression to the new 29 // schema of the flattened node. 30 // e.g. GroupBy(sum(a) + sum(b)) becomes project(sum(a) + sum(b), GroupBy(sum(a), sum(b)). 31 // e.g. Window(sum(a) + sum(b) over (partition by a)) becomes 32 // project(sum(a) + sum(b) over (partition by a), Window(sum(a), sum(b) over (partition by a))). 33 func flattenAggregationExpressions(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 34 span, ctx := ctx.Span("flatten_aggregation_exprs") 35 defer span.End() 36 37 if !n.Resolved() { 38 return n, transform.SameTree, nil 39 } 40 41 return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { 42 switch n := n.(type) { 43 case *plan.Window: 44 if !hasHiddenAggregations(n.SelectExprs) && !hasHiddenWindows(n.SelectExprs) { 45 return n, transform.SameTree, nil 46 } 47 48 return flattenedWindow(ctx, scope, n.SelectExprs, n.Child) 49 case *plan.GroupBy: 50 if !hasHiddenAggregations(n.SelectedExprs) { 51 return n, transform.SameTree, nil 52 } 53 54 return flattenedGroupBy(ctx, scope, n.SelectedExprs, n.GroupByExprs, n.Child) 55 default: 56 return n, transform.SameTree, nil 57 } 58 }) 59 } 60 61 func flattenedGroupBy(ctx *sql.Context, scope *plan.Scope, projection, grouping []sql.Expression, child sql.Node) (sql.Node, transform.TreeIdentity, error) { 62 newProjection, newAggregates, allSame, err := replaceAggregatesWithGetFieldProjections(ctx, scope, projection) 63 if err != nil { 64 return nil, transform.SameTree, err 65 } 66 if allSame { 67 return nil, transform.SameTree, nil 68 } 69 return plan.NewProject( 70 newProjection, 71 plan.NewGroupBy(newAggregates, grouping, child), 72 ), transform.NewTree, nil 73 } 74 75 // replaceAggregatesWithGetFieldProjections inserts an indirection Projection 76 // between an aggregation and its scope output, resulting in two buckets of 77 // expressions: 78 // 1) Parent projection expressions. 79 // 2) Child aggregation expressions. 80 // 81 // A scope always returns a fixed number of columns, so the number of projection 82 // inputs and outputs must match. 83 // 84 // The aggregation must provide input dependencies for parent projections. 85 // Each parent expression can depend on zero or many aggregation expressions. 86 // There are two basic kinds of aggregation expressions: 87 // 1) Passthrough columns from scope input relation. 88 // 2) Synthesized columns from in-scope aggregation relation. 89 func replaceAggregatesWithGetFieldProjections(_ *sql.Context, scope *plan.Scope, projection []sql.Expression) (projections, aggregations []sql.Expression, identity transform.TreeIdentity, err error) { 90 var newProjection = make([]sql.Expression, len(projection)) 91 var newAggregates []sql.Expression 92 scopeLen := len(scope.Schema()) 93 aggPassthrough := make(map[string]struct{}) 94 /* every aggregation creates one pass-through reference into parent */ 95 for i, p := range projection { 96 e, same, err := transform.Expr(p, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 97 switch e := e.(type) { 98 case sql.Aggregation, sql.WindowAggregation: 99 newAggregates = append(newAggregates, e) 100 aggPassthrough[e.String()] = struct{}{} 101 typ := e.Type() 102 switch e.(type) { 103 case *aggregation.Sum, *aggregation.Avg: 104 typ = types.Float64 105 case *aggregation.Count: 106 typ = types.Int64 107 } 108 return expression.NewGetField(scopeLen+len(newAggregates)-1, typ, e.String(), e.IsNullable()), transform.NewTree, nil 109 default: 110 return e, transform.SameTree, nil 111 } 112 }) 113 if err != nil { 114 return nil, nil, transform.SameTree, err 115 } 116 117 if same { 118 var getField *expression.GetField 119 // add to plan.GroupBy.SelectedExprs iff expression has an expression.GetField 120 hasGetField := transform.InspectExpr(e, func(expr sql.Expression) bool { 121 gf, ok := expr.(*expression.GetField) 122 if ok { 123 getField = gf 124 } 125 return ok 126 }) 127 if hasGetField { 128 newAggregates = append(newAggregates, e) 129 name, source := getNameAndSource(e) 130 newProjection[i] = expression.NewGetFieldWithTable( 131 scopeLen+len(newAggregates)-1, int(getField.TableId()), e.Type(), getField.Database(), source, name, e.IsNullable(), 132 ) 133 } else { 134 newProjection[i] = e 135 } 136 } else { 137 newProjection[i] = e 138 transform.InspectExpr(e, func(e sql.Expression) bool { 139 // clean up projection dependency columns not synthesized by 140 // aggregation. 141 switch e := e.(type) { 142 case *expression.GetField: 143 if _, ok := aggPassthrough[e.Name()]; !ok { 144 // this is a column input to the projection that 145 // the aggregation parent has not passed-through. 146 // TODO: for functions without aggregate dependency, 147 // we just execute the function in the aggregation. 148 // why don't we do that for both? 149 newAggregates = append(newAggregates, e) 150 } 151 default: 152 } 153 return false 154 }) 155 } 156 } 157 158 return newProjection, newAggregates, transform.NewTree, nil 159 } 160 161 func flattenedWindow(ctx *sql.Context, scope *plan.Scope, projection []sql.Expression, child sql.Node) (sql.Node, transform.TreeIdentity, error) { 162 newProjection, newAggregates, allSame, err := replaceAggregatesWithGetFieldProjections(ctx, scope, projection) 163 if err != nil { 164 return nil, transform.SameTree, err 165 } 166 if allSame { 167 return nil, allSame, nil 168 } 169 return plan.NewProject( 170 newProjection, 171 plan.NewWindow(newAggregates, child), 172 ), transform.NewTree, nil 173 } 174 175 func getNameAndSource(e sql.Expression) (name, source string) { 176 if n, ok := e.(sql.Nameable); ok { 177 name = n.Name() 178 } else { 179 name = e.String() 180 } 181 182 if t, ok := e.(sql.Tableable); ok { 183 source = t.Table() 184 } 185 186 return 187 } 188 189 // hasHiddenAggregations returns whether any of the given expressions has a hidden aggregation. That is, an aggregation 190 // that is not at the root of the expression. 191 func hasHiddenAggregations(exprs []sql.Expression) bool { 192 for _, e := range exprs { 193 if containsHiddenAggregation(e) { 194 return true 195 } 196 } 197 return false 198 } 199 200 // containsHiddenAggregation returns whether the given expressions has a hidden aggregation. That is, an aggregation 201 // that is not at the root of the expression. 202 func containsHiddenAggregation(e sql.Expression) bool { 203 _, ok := e.(sql.Aggregation) 204 if ok { 205 return false 206 } 207 208 return containsAggregation(e) 209 } 210 211 // containsAggregation returns whether the expression given contains any sql.Aggregation terms. 212 func containsAggregation(e sql.Expression) bool { 213 var hasAgg bool 214 sql.Inspect(e, func(e sql.Expression) bool { 215 if _, ok := e.(sql.Aggregation); ok { 216 hasAgg = true 217 return false 218 } 219 return true 220 }) 221 return hasAgg 222 } 223 224 // hasHiddenWindows returns whether any of the given expression have a hidden window function. That is, a window 225 // function that is not at the root of the expression. 226 func hasHiddenWindows(exprs []sql.Expression) bool { 227 for _, e := range exprs { 228 if containsHiddenWindow(e) { 229 return true 230 } 231 } 232 return false 233 } 234 235 // containsHiddenWindow returns whether the given expression has a hidden window function. That is, a window function 236 // that is not at the root of the expression. 237 func containsHiddenWindow(e sql.Expression) bool { 238 _, ok := e.(sql.WindowAggregation) 239 if ok { 240 return false 241 } 242 243 return containsWindow(e) 244 } 245 246 // containsWindow returns whether the expression given contains any sql.WindowAggregation terms. 247 func containsWindow(e sql.Expression) bool { 248 var hasAgg bool 249 sql.Inspect(e, func(e sql.Expression) bool { 250 if _, ok := e.(sql.WindowAggregation); ok { 251 hasAgg = true 252 return false 253 } 254 return true 255 }) 256 return hasAgg 257 }