vitess.io/vitess@v0.16.2/go/vt/vtgate/planbuilder/ordered_aggregate.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package planbuilder
    18  
    19  import (
    20  	"fmt"
    21  	"strconv"
    22  	"strings"
    23  
    24  	"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
    25  
    26  	"vitess.io/vitess/go/mysql/collations"
    27  
    28  	"vitess.io/vitess/go/sqltypes"
    29  
    30  	"vitess.io/vitess/go/vt/vterrors"
    31  
    32  	"vitess.io/vitess/go/vt/sqlparser"
    33  	"vitess.io/vitess/go/vt/vtgate/engine"
    34  )
    35  
// Compile-time assertion that *orderedAggregate satisfies the logicalPlan interface.
var _ logicalPlan = (*orderedAggregate)(nil)
    37  
// orderedAggregate is the logicalPlan for engine.OrderedAggregate.
// This gets built if there are aggregations on a SelectScatter
// route. The primitive requests the underlying route to order
// the results by the grouping columns. This will allow the
// engine code to aggregate the results as they come.
// For example: 'select col1, col2, count(*) from t group by col1, col2'
// will be sent to the scatter route as:
// 'select col1, col2, count(*) from t group by col1, col2 order by col1, col2'
// The orderedAggregate primitive built for this will be:
//
//	&engine.OrderedAggregate {
//	  // Aggregates has one column. It computes the count
//	  // using column 2 of the underlying route.
//	  Aggregates: []AggregateParams{{
//	    Opcode: AggregateCount,
//	    Col: 2,
//	  }},
//
//	  // Keys has the two group by values for col1 and col2.
//	  // The column numbers are from the underlying route.
//	  // These values will be used to perform the grouping
//	  // of the ordered results as they come from the underlying
//	  // route.
//	  Keys: []int{0, 1},
//	  Input: (Scatter Route with the order by request),
//	}
//
// When there are no group by keys, Primitive builds an
// engine.ScalarAggregate instead of an engine.OrderedAggregate.
type orderedAggregate struct {
	// resultsBuilder is embedded; in this file it supplies the input plan,
	// the resultColumns slice and the weightStrings map.
	resultsBuilder

	// extraDistinct is the column built for the expression inside the
	// DISTINCT aggregate that this primitive handles itself. pushAggr sets
	// it and rejects a second such aggregate while it is non-nil.
	extraDistinct *sqlparser.ColName

	// preProcess is true if one of the aggregates needs preprocessing.
	preProcess bool

	// aggrOnEngine is forwarded as AggrOnEngine to the engine primitive
	// built by Primitive. It is not assigned anywhere in this file.
	aggrOnEngine bool

	// aggregates specifies the aggregation parameters for each
	// aggregation function: function opcode and input column number.
	aggregates []*engine.AggregateParams

	// groupByKeys specifies the input values that must be used for
	// the aggregation key.
	groupByKeys []*engine.GroupByParams

	// truncateColumnCount, when greater than zero, limits OutputColumns to
	// that many leading columns. It is also forwarded to the engine
	// primitive as TruncateColumnCount.
	truncateColumnCount int
}
    83  
    84  // checkAggregates analyzes the select expression for aggregates. If it determines
    85  // that a primitive is needed to handle the aggregation, it builds an orderedAggregate
    86  // primitive and returns it. It returns a groupByHandler if there is aggregation it
    87  // can handle.
    88  func (pb *primitiveBuilder) checkAggregates(sel *sqlparser.Select) error {
    89  	rb, isRoute := pb.plan.(*route)
    90  	if isRoute && rb.isSingleShard() {
    91  		// since we can push down all of the aggregation to the route,
    92  		// we don't need to do anything else here
    93  		return nil
    94  	}
    95  
    96  	// Check if we can allow aggregates.
    97  	hasAggregates := sqlparser.ContainsAggregation(sel.SelectExprs) || len(sel.GroupBy) > 0
    98  	if !hasAggregates && !sel.Distinct {
    99  		return nil
   100  	}
   101  
   102  	// The query has aggregates. We can proceed only
   103  	// if the underlying primitive is a route because
   104  	// we need the ability to push down group by and
   105  	// order by clauses.
   106  	if !isRoute {
   107  		if hasAggregates {
   108  			return vterrors.VT12001("cross-shard query with aggregates")
   109  		}
   110  		pb.plan = newDistinctV3(pb.plan)
   111  		return nil
   112  	}
   113  
   114  	// If there is a distinct clause, we can check the select list
   115  	// to see if it has a unique vindex reference. For example,
   116  	// if the query was 'select distinct id, col from t' (with id
   117  	// as a unique vindex), then the distinct operation can be
   118  	// safely pushed down because the unique vindex guarantees
   119  	// that each id can only be in a single shard. Without the
   120  	// unique vindex property, the id could come from multiple
   121  	// shards, which will require us to perform the grouping
   122  	// at the vtgate level.
   123  	if sel.Distinct {
   124  		for _, selectExpr := range sel.SelectExprs {
   125  			switch selectExpr := selectExpr.(type) {
   126  			case *sqlparser.AliasedExpr:
   127  				vindex := pb.st.Vindex(selectExpr.Expr, rb)
   128  				if vindex != nil && vindex.IsUnique() {
   129  					return nil
   130  				}
   131  			}
   132  		}
   133  	}
   134  
   135  	// The group by clause could also reference a unique vindex. The above
   136  	// example could itself have been written as
   137  	// 'select id, col from t group by id, col', or a query could be like
   138  	// 'select id, count(*) from t group by id'. In the above cases,
   139  	// the grouping can be done at the shard level, which allows the entire query
   140  	// to be pushed down. In order to perform this analysis, we're going to look
   141  	// ahead at the group by clause to see if it references a unique vindex.
   142  	if pb.groupByHasUniqueVindex(sel, rb) {
   143  		return nil
   144  	}
   145  
   146  	// We need an aggregator primitive.
   147  	oa := &orderedAggregate{}
   148  	oa.resultsBuilder = newResultsBuilder(rb, oa)
   149  	pb.plan = oa
   150  	pb.plan.Reorder(0)
   151  	return nil
   152  }
   153  
   154  // groupbyHasUniqueVindex looks ahead at the group by expression to see if
   155  // it references a unique vindex.
   156  //
   157  // The vitess group by rules are different from MySQL because it's not possible
   158  // to match the MySQL behavior without knowing the schema. For example:
   159  // 'select id as val from t group by val' will have different interpretations
   160  // under MySQL depending on whether t has a val column or not.
   161  // In vitess, we always assume that 'val' references 'id'. This is achieved
   162  // by the symbol table resolving against the select list before searching
   163  // the tables.
   164  //
   165  // In order to look ahead, we have to overcome the chicken-and-egg problem:
   166  // group by needs the select aliases to be built. Select aliases are built
   167  // on push-down. But push-down decision depends on whether group by expressions
   168  // reference a vindex.
   169  // To overcome this, the look-ahead has to perform a search that matches
   170  // the group by analyzer. The flow is similar to oa.PushGroupBy, except that
   171  // we don't search the ResultColumns because they're not created yet. Also,
   172  // error conditions are treated as no match for simplicity; They will be
   173  // subsequently caught downstream.
   174  func (pb *primitiveBuilder) groupByHasUniqueVindex(sel *sqlparser.Select, rb *route) bool {
   175  	for _, expr := range sel.GroupBy {
   176  		var matchedExpr sqlparser.Expr
   177  		switch node := expr.(type) {
   178  		case *sqlparser.ColName:
   179  			if expr := findAlias(node, sel.SelectExprs); expr != nil {
   180  				matchedExpr = expr
   181  			} else {
   182  				matchedExpr = node
   183  			}
   184  		case *sqlparser.Literal:
   185  			if node.Type != sqlparser.IntVal {
   186  				continue
   187  			}
   188  			num, err := strconv.ParseInt(string(node.Val), 0, 64)
   189  			if err != nil {
   190  				continue
   191  			}
   192  			if num < 1 || num > int64(len(sel.SelectExprs)) {
   193  				continue
   194  			}
   195  			expr, ok := sel.SelectExprs[num-1].(*sqlparser.AliasedExpr)
   196  			if !ok {
   197  				continue
   198  			}
   199  			matchedExpr = expr.Expr
   200  		default:
   201  			continue
   202  		}
   203  		vindex := pb.st.Vindex(matchedExpr, rb)
   204  		if vindex != nil && vindex.IsUnique() {
   205  			return true
   206  		}
   207  	}
   208  	return false
   209  }
   210  
   211  func findAlias(colname *sqlparser.ColName, selects sqlparser.SelectExprs) sqlparser.Expr {
   212  	// Qualified column names cannot match an (unqualified) alias.
   213  	if !colname.Qualifier.IsEmpty() {
   214  		return nil
   215  	}
   216  	// See if this references an alias.
   217  	for _, selectExpr := range selects {
   218  		selectExpr, ok := selectExpr.(*sqlparser.AliasedExpr)
   219  		if !ok {
   220  			continue
   221  		}
   222  		if colname.Name.Equal(selectExpr.As) {
   223  			return selectExpr.Expr
   224  		}
   225  	}
   226  	return nil
   227  }
   228  
   229  // Primitive implements the logicalPlan interface
   230  func (oa *orderedAggregate) Primitive() engine.Primitive {
   231  	colls := map[int]collations.ID{}
   232  	for _, key := range oa.aggregates {
   233  		if key.CollationID != collations.Unknown {
   234  			colls[key.KeyCol] = key.CollationID
   235  		}
   236  	}
   237  	for _, key := range oa.groupByKeys {
   238  		if key.CollationID != collations.Unknown {
   239  			colls[key.KeyCol] = key.CollationID
   240  		}
   241  	}
   242  
   243  	input := oa.input.Primitive()
   244  	if len(oa.groupByKeys) == 0 {
   245  		return &engine.ScalarAggregate{
   246  			PreProcess:          oa.preProcess,
   247  			AggrOnEngine:        oa.aggrOnEngine,
   248  			Aggregates:          oa.aggregates,
   249  			TruncateColumnCount: oa.truncateColumnCount,
   250  			Collations:          colls,
   251  			Input:               input,
   252  		}
   253  	}
   254  
   255  	return &engine.OrderedAggregate{
   256  		PreProcess:          oa.preProcess,
   257  		AggrOnEngine:        oa.aggrOnEngine,
   258  		Aggregates:          oa.aggregates,
   259  		GroupByKeys:         oa.groupByKeys,
   260  		TruncateColumnCount: oa.truncateColumnCount,
   261  		Collations:          colls,
   262  		Input:               input,
   263  	}
   264  }
   265  
   266  func (oa *orderedAggregate) pushAggr(pb *primitiveBuilder, expr *sqlparser.AliasedExpr, origin logicalPlan) (rc *resultColumn, colNumber int, err error) {
   267  	aggrFunc, _ := expr.Expr.(sqlparser.AggrFunc)
   268  	origOpcode := engine.SupportedAggregates[strings.ToLower(aggrFunc.AggrName())]
   269  	opcode := origOpcode
   270  	if aggrFunc.GetArgs() != nil &&
   271  		len(aggrFunc.GetArgs()) != 1 {
   272  		return nil, 0, vterrors.VT12001(fmt.Sprintf("only one expression is allowed inside aggregates: %s", sqlparser.String(expr)))
   273  	}
   274  
   275  	handleDistinct, innerAliased, err := oa.needDistinctHandling(pb, expr, opcode)
   276  	if err != nil {
   277  		return nil, 0, err
   278  	}
   279  	if handleDistinct {
   280  		if oa.extraDistinct != nil {
   281  			return nil, 0, vterrors.VT12001(fmt.Sprintf("only one DISTINCT aggregation allowed in a SELECT: %s", sqlparser.String(expr)))
   282  		}
   283  		// Push the expression that's inside the aggregate.
   284  		// The column will eventually get added to the group by and order by clauses.
   285  		newBuilder, _, innerCol, err := planProjection(pb, oa.input, innerAliased, origin)
   286  		if err != nil {
   287  			return nil, 0, err
   288  		}
   289  		pb.plan = newBuilder
   290  		col, err := BuildColName(oa.input.ResultColumns(), innerCol)
   291  		if err != nil {
   292  			return nil, 0, err
   293  		}
   294  		oa.extraDistinct = col
   295  		oa.preProcess = true
   296  		switch opcode {
   297  		case engine.AggregateCount:
   298  			opcode = engine.AggregateCountDistinct
   299  		case engine.AggregateSum:
   300  			opcode = engine.AggregateSumDistinct
   301  		}
   302  		oa.aggregates = append(oa.aggregates, &engine.AggregateParams{
   303  			Opcode:     opcode,
   304  			Col:        innerCol,
   305  			Alias:      expr.ColumnName(),
   306  			OrigOpcode: origOpcode,
   307  		})
   308  	} else {
   309  		newBuilder, _, innerCol, err := planProjection(pb, oa.input, expr, origin)
   310  		if err != nil {
   311  			return nil, 0, err
   312  		}
   313  		pb.plan = newBuilder
   314  		oa.aggregates = append(oa.aggregates, &engine.AggregateParams{
   315  			Opcode:     opcode,
   316  			Col:        innerCol,
   317  			OrigOpcode: origOpcode,
   318  		})
   319  	}
   320  
   321  	// Build a new rc with oa as origin because it's semantically different
   322  	// from the expression we pushed down.
   323  	rc = newResultColumn(expr, oa)
   324  	oa.resultColumns = append(oa.resultColumns, rc)
   325  	return rc, len(oa.resultColumns) - 1, nil
   326  }
   327  
   328  // needDistinctHandling returns true if oa needs to handle the distinct clause.
   329  // If true, it will also return the aliased expression that needs to be pushed
   330  // down into the underlying route.
   331  func (oa *orderedAggregate) needDistinctHandling(pb *primitiveBuilder, expr *sqlparser.AliasedExpr, opcode engine.AggregateOpcode) (bool, *sqlparser.AliasedExpr, error) {
   332  	var innerAliased *sqlparser.AliasedExpr
   333  	aggr, ok := expr.Expr.(sqlparser.AggrFunc)
   334  
   335  	if !ok {
   336  		return false, nil, vterrors.VT03012(sqlparser.String(expr))
   337  	}
   338  
   339  	if !aggr.IsDistinct() {
   340  		return false, nil, nil
   341  	}
   342  	if opcode != engine.AggregateCount && opcode != engine.AggregateSum && opcode != engine.AggregateCountStar {
   343  		return false, nil, nil
   344  	}
   345  
   346  	innerAliased = &sqlparser.AliasedExpr{Expr: aggr.GetArg()}
   347  
   348  	rb, ok := oa.input.(*route)
   349  	if !ok {
   350  		// Unreachable
   351  		return true, innerAliased, nil
   352  	}
   353  	vindex := pb.st.Vindex(innerAliased.Expr, rb)
   354  	if vindex != nil && vindex.IsUnique() {
   355  		return false, nil, nil
   356  	}
   357  	return true, innerAliased, nil
   358  }
   359  
   360  // Wireup implements the logicalPlan interface
   361  // If text columns are detected in the keys, then the function modifies
   362  // the primitive to pull a corresponding weight_string from mysql and
   363  // compare those instead. This is because we currently don't have the
   364  // ability to mimic mysql's collation behavior.
   365  func (oa *orderedAggregate) Wireup(plan logicalPlan, jt *jointab) error {
   366  	for i, gbk := range oa.groupByKeys {
   367  		rc := oa.resultColumns[gbk.KeyCol]
   368  		if sqltypes.IsText(rc.column.typ) {
   369  			weightcolNumber, err := oa.input.SupplyWeightString(gbk.KeyCol, gbk.FromGroupBy)
   370  			if err != nil {
   371  				_, isUnsupportedErr := err.(UnsupportedSupplyWeightString)
   372  				if isUnsupportedErr {
   373  					continue
   374  				}
   375  				return err
   376  			}
   377  			oa.weightStrings[rc] = weightcolNumber
   378  			oa.groupByKeys[i].WeightStringCol = weightcolNumber
   379  			oa.groupByKeys[i].KeyCol = weightcolNumber
   380  			oa.truncateColumnCount = len(oa.resultColumns)
   381  		}
   382  	}
   383  	for _, key := range oa.aggregates {
   384  		switch key.Opcode {
   385  		case engine.AggregateCount:
   386  			if key.Alias == "" {
   387  				key.Alias = key.Opcode.String()
   388  			}
   389  			key.Opcode = engine.AggregateSum
   390  		}
   391  	}
   392  
   393  	return oa.input.Wireup(plan, jt)
   394  }
   395  
// WireupGen4 delegates the Gen4 wire-up to the input plan; orderedAggregate
// performs no work of its own here.
func (oa *orderedAggregate) WireupGen4(ctx *plancontext.PlanningContext) error {
	return oa.input.WireupGen4(ctx)
}
   399  
   400  // OutputColumns implements the logicalPlan interface
   401  func (oa *orderedAggregate) OutputColumns() []sqlparser.SelectExpr {
   402  	outputCols := sqlparser.CloneSelectExprs(oa.input.OutputColumns())
   403  	for _, aggr := range oa.aggregates {
   404  		outputCols[aggr.Col] = &sqlparser.AliasedExpr{Expr: aggr.Expr, As: sqlparser.NewIdentifierCI(aggr.Alias)}
   405  	}
   406  	if oa.truncateColumnCount > 0 {
   407  		return outputCols[:oa.truncateColumnCount]
   408  	}
   409  	return outputCols
   410  }
   411  
// SetTruncateColumnCount sets the truncate column count.
// A count greater than zero makes OutputColumns return only that many
// leading columns; the value is also passed to the engine primitive built
// by Primitive.
func (oa *orderedAggregate) SetTruncateColumnCount(count int) {
	oa.truncateColumnCount = count
}