github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/replace_count_star.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation" 19 "github.com/dolthub/go-mysql-server/sql/transform" 20 "github.com/dolthub/go-mysql-server/sql/types" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/expression" 24 "github.com/dolthub/go-mysql-server/sql/plan" 25 ) 26 27 // replaceCountStar replaces count(*) expressions with count(1) expressions, which are semantically equivalent and 28 // lets us prune all the unused columns from the target tables. 29 func replaceCountStar(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 30 if plan.IsDDLNode(n) { 31 return n, transform.SameTree, nil 32 } 33 34 return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { 35 if agg, ok := n.(*plan.GroupBy); ok { 36 if len(agg.SelectedExprs) == 1 && len(agg.GroupByExprs) == 0 { 37 child := agg.SelectedExprs[0] 38 var cnt *aggregation.Count 39 name := "" 40 if alias, ok := child.(*expression.Alias); ok { 41 cnt, _ = alias.Child.(*aggregation.Count) 42 name = alias.Name() 43 } else { 44 cnt, _ = child.(*aggregation.Count) 45 name = child.String() 46 } 47 if cnt != nil { 48 switch cnt.Child.(type) { 49 case *expression.Star, *expression.Literal: 50 var rt *plan.ResolvedTable 51 switch c := agg.Child.(type) { 52 case *plan.ResolvedTable: 53 rt = c 54 case *plan.TableAlias: 55 if t, ok := c.Child.(*plan.ResolvedTable); ok { 56 rt = t 57 } 58 } 59 if rt != nil && !sql.IsKeyless(rt.Table.Schema()) { 60 if statsTable, ok := rt.Table.(sql.StatisticsTable); ok { 61 rowCnt, exact, err := statsTable.RowCount(ctx) 62 if err == nil && exact { 63 return plan.NewProject( 64 []sql.Expression{ 65 expression.NewAlias(name, expression.NewGetFieldWithTable(int(cnt.Id()), 0, types.Int64, rt.Database().Name(), statsTable.Name(), name, false)).WithId(cnt.Id()), 66 }, 67 plan.NewTableCount(name, rt.SqlDatabase, statsTable, rowCnt, cnt.Id()), 68 ), transform.NewTree, nil 69 } 70 } 71 } 72 } 73 } 74 } 75 } 76 77 return transform.NodeExprs(n, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 78 if count, ok := e.(*aggregation.Count); ok { 79 if _, ok := count.Child.(*expression.Star); ok { 80 count, err := count.WithChildren(expression.NewLiteral(int64(1), types.Int64)) 81 if err != nil { 82 return nil, transform.SameTree, err 83 } 84 return count, transform.NewTree, nil 85 } 86 } 87 88 return e, transform.SameTree, nil 89 }) 90 }) 91 }