vitess.io/vitess@v0.16.2/go/vt/vtgate/planbuilder/grouping.go (about)

     1  /*
     2  Copyright 2020 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package planbuilder
    18  
    19  import (
    20  	"fmt"
    21  
    22  	"vitess.io/vitess/go/vt/sqlparser"
    23  	"vitess.io/vitess/go/vt/vterrors"
    24  	"vitess.io/vitess/go/vt/vtgate/engine"
    25  )
    26  
    27  func planGroupBy(pb *primitiveBuilder, input logicalPlan, groupBy sqlparser.GroupBy) (logicalPlan, error) {
    28  	if len(groupBy) == 0 {
    29  		// if we have no grouping declared, we only want to visit orderedAggregate
    30  		_, isOrdered := input.(*orderedAggregate)
    31  		if !isOrdered {
    32  			return input, nil
    33  		}
    34  	}
    35  
    36  	switch node := input.(type) {
    37  	case *mergeSort, *pulloutSubquery, *distinct:
    38  		inputs := node.Inputs()
    39  		input := inputs[0]
    40  
    41  		newInput, err := planGroupBy(pb, input, groupBy)
    42  		if err != nil {
    43  			return nil, err
    44  		}
    45  		inputs[0] = newInput
    46  		err = node.Rewrite(inputs...)
    47  		if err != nil {
    48  			return nil, err
    49  		}
    50  		return node, nil
    51  	case *route:
    52  		node.Select.(*sqlparser.Select).GroupBy = groupBy
    53  		return node, nil
    54  	case *orderedAggregate:
    55  		for _, expr := range groupBy {
    56  			colNumber := -1
    57  			switch e := expr.(type) {
    58  			case *sqlparser.ColName:
    59  				c := e.Metadata.(*column)
    60  				if c.Origin() == node {
    61  					return nil, vterrors.VT03005(sqlparser.String(e))
    62  				}
    63  				for i, rc := range node.resultColumns {
    64  					if rc.column == c {
    65  						colNumber = i
    66  						break
    67  					}
    68  				}
    69  				if colNumber == -1 {
    70  					return nil, vterrors.VT12001("in scatter query: GROUP BY column must reference column in SELECT list")
    71  				}
    72  			case *sqlparser.Literal:
    73  				num, err := ResultFromNumber(node.resultColumns, e, "group statement")
    74  				if err != nil {
    75  					return nil, err
    76  				}
    77  				colNumber = num
    78  			default:
    79  				return nil, vterrors.VT12001("in scatter query: only simple references are allowed")
    80  			}
    81  			node.groupByKeys = append(node.groupByKeys, &engine.GroupByParams{KeyCol: colNumber, WeightStringCol: -1, FromGroupBy: true})
    82  		}
    83  		// Append the distinct aggregate if any.
    84  		if node.extraDistinct != nil {
    85  			groupBy = append(groupBy, node.extraDistinct)
    86  		}
    87  
    88  		newInput, err := planGroupBy(pb, node.input, groupBy)
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  		node.input = newInput
    93  
    94  		return node, nil
    95  	}
    96  	return nil, vterrors.VT13001(fmt.Sprintf("unreachable %T.groupBy: ", input))
    97  }
    98  
    99  // planDistinct makes the output distinct
   100  func planDistinct(input logicalPlan) (logicalPlan, error) {
   101  	switch node := input.(type) {
   102  	case *route:
   103  		node.Select.MakeDistinct()
   104  		return node, nil
   105  	case *orderedAggregate:
   106  		for i, rc := range node.resultColumns {
   107  			// If the column origin is oa (and not the underlying route),
   108  			// it means that it's an aggregate function supplied by oa.
   109  			// So, the distinct 'operator' cannot be pushed down into the
   110  			// route.
   111  			if rc.column.Origin() == node {
   112  				return newDistinctV3(node), nil
   113  			}
   114  			node.groupByKeys = append(node.groupByKeys, &engine.GroupByParams{KeyCol: i, WeightStringCol: -1, FromGroupBy: false})
   115  		}
   116  		newInput, err := planDistinct(node.input)
   117  		if err != nil {
   118  			return nil, err
   119  		}
   120  		node.input = newInput
   121  		return node, nil
   122  
   123  	case *distinct:
   124  		return input, nil
   125  	}
   126  
   127  	return nil, vterrors.VT13001(fmt.Sprintf("unreachable %T.distinct", input))
   128  }