vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/scalar_aggregation.go (about)

     1  /*
     2  Copyright 2022 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"sync"
    22  
    23  	"vitess.io/vitess/go/mysql/collations"
    24  	"vitess.io/vitess/go/sqltypes"
    25  	querypb "vitess.io/vitess/go/vt/proto/query"
    26  	"vitess.io/vitess/go/vt/proto/vtrpc"
    27  	"vitess.io/vitess/go/vt/vterrors"
    28  )
    29  
// Compile-time check that *ScalarAggregate satisfies the Primitive interface.
var _ Primitive = (*ScalarAggregate)(nil)

// ScalarAggregate is a primitive used to do aggregations without grouping keys.
// It always returns exactly one row: either the input rows folded together
// according to Aggregates, or, when the input yields no rows, a row holding
// the empty-input value of every aggregate.
type ScalarAggregate struct {
	// PreProcess is true if one of the aggregates needs preprocessing.
	// It is forwarded to convertFields and convertRow.
	PreProcess bool `json:",omitempty"`

	// AggrOnEngine is forwarded to convertFields and convertRow together with
	// PreProcess. NOTE(review): presumably signals that the aggregation is
	// evaluated on vtgate rather than pushed down — confirm with the planner.
	AggrOnEngine bool

	// Aggregates specifies the aggregation parameters for each
	// aggregation function: function opcode and input column number.
	Aggregates []*AggregateParams

	// TruncateColumnCount specifies the number of columns to return
	// in the final result. Rest of the columns are truncated
	// from the result received. If 0, no truncation happens.
	TruncateColumnCount int `json:",omitempty"`

	// Collations stores the collation ID per column offset.
	// It is used for grouping keys and distinct aggregate functions.
	// A scalar aggregate has no grouping keys; here the map is only handed
	// to merge when rows are combined.
	Collations map[int]collations.ID

	// Input is the primitive that will feed into this Primitive.
	Input Primitive
}
    55  
    56  // RouteType implements the Primitive interface
    57  func (sa *ScalarAggregate) RouteType() string {
    58  	return sa.Input.RouteType()
    59  }
    60  
    61  // GetKeyspaceName implements the Primitive interface
    62  func (sa *ScalarAggregate) GetKeyspaceName() string {
    63  	return sa.Input.GetKeyspaceName()
    64  
    65  }
    66  
    67  // GetTableName implements the Primitive interface
    68  func (sa *ScalarAggregate) GetTableName() string {
    69  	return sa.Input.GetTableName()
    70  }
    71  
    72  // GetFields implements the Primitive interface
    73  func (sa *ScalarAggregate) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
    74  	qr, err := sa.Input.GetFields(ctx, vcursor, bindVars)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  	qr = &sqltypes.Result{Fields: convertFields(qr.Fields, sa.PreProcess, sa.Aggregates, sa.AggrOnEngine)}
    79  	return qr.Truncate(sa.TruncateColumnCount), nil
    80  }
    81  
    82  // NeedsTransaction implements the Primitive interface
    83  func (sa *ScalarAggregate) NeedsTransaction() bool {
    84  	return sa.Input.NeedsTransaction()
    85  }
    86  
    87  // TryExecute implements the Primitive interface
    88  func (sa *ScalarAggregate) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
    89  	result, err := vcursor.ExecutePrimitive(ctx, sa.Input, bindVars, wantfields)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  	out := &sqltypes.Result{
    94  		Fields: convertFields(result.Fields, sa.PreProcess, sa.Aggregates, sa.AggrOnEngine),
    95  	}
    96  
    97  	var resultRow []sqltypes.Value
    98  	var curDistincts []sqltypes.Value
    99  	for _, row := range result.Rows {
   100  		if resultRow == nil {
   101  			resultRow, curDistincts = convertRow(row, sa.PreProcess, sa.Aggregates, sa.AggrOnEngine)
   102  			continue
   103  		}
   104  		resultRow, curDistincts, err = merge(result.Fields, resultRow, row, curDistincts, sa.Collations, sa.Aggregates)
   105  		if err != nil {
   106  			return nil, err
   107  		}
   108  	}
   109  
   110  	if resultRow == nil {
   111  		// When doing aggregation without grouping keys, we need to produce a single row containing zero-value for the
   112  		// different aggregation functions
   113  		resultRow, err = sa.createEmptyRow()
   114  		if err != nil {
   115  			return nil, err
   116  		}
   117  	} else {
   118  		resultRow, err = convertFinal(resultRow, sa.Aggregates)
   119  		if err != nil {
   120  			return nil, err
   121  		}
   122  	}
   123  
   124  	out.Rows = [][]sqltypes.Value{resultRow}
   125  	return out.Truncate(sa.TruncateColumnCount), nil
   126  }
   127  
   128  // TryStreamExecute implements the Primitive interface
   129  func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
   130  	cb := func(qr *sqltypes.Result) error {
   131  		return callback(qr.Truncate(sa.TruncateColumnCount))
   132  	}
   133  	var current []sqltypes.Value
   134  	var curDistincts []sqltypes.Value
   135  	var fields []*querypb.Field
   136  	fieldsSent := false
   137  	var mu sync.Mutex
   138  
   139  	err := vcursor.StreamExecutePrimitive(ctx, sa.Input, bindVars, wantfields, func(result *sqltypes.Result) error {
   140  		// as the underlying primitive call is not sync
   141  		// and here scalar aggregate is using shared variables we have to sync the callback
   142  		// for correct aggregation.
   143  		mu.Lock()
   144  		defer mu.Unlock()
   145  		if len(result.Fields) != 0 && !fieldsSent {
   146  			fields = convertFields(result.Fields, sa.PreProcess, sa.Aggregates, sa.AggrOnEngine)
   147  			if err := cb(&sqltypes.Result{Fields: fields}); err != nil {
   148  				return err
   149  			}
   150  			fieldsSent = true
   151  		}
   152  
   153  		// this code is very similar to the TryExecute method
   154  		for _, row := range result.Rows {
   155  			if current == nil {
   156  				current, curDistincts = convertRow(row, sa.PreProcess, sa.Aggregates, sa.AggrOnEngine)
   157  				continue
   158  			}
   159  			var err error
   160  			current, curDistincts, err = merge(fields, current, row, curDistincts, sa.Collations, sa.Aggregates)
   161  			if err != nil {
   162  				return err
   163  			}
   164  		}
   165  		return nil
   166  	})
   167  	if err != nil {
   168  		return err
   169  	}
   170  
   171  	if current == nil {
   172  		// When doing aggregation without grouping keys, we need to produce a single row containing zero-value for the
   173  		// different aggregation functions
   174  		current, err = sa.createEmptyRow()
   175  		if err != nil {
   176  			return err
   177  		}
   178  	} else {
   179  		current, err = convertFinal(current, sa.Aggregates)
   180  		if err != nil {
   181  			return err
   182  		}
   183  	}
   184  
   185  	return cb(&sqltypes.Result{Rows: [][]sqltypes.Value{current}})
   186  }
   187  
   188  // creates the empty row for the case when we are missing grouping keys and have empty input table
   189  func (sa *ScalarAggregate) createEmptyRow() ([]sqltypes.Value, error) {
   190  	out := make([]sqltypes.Value, len(sa.Aggregates))
   191  	for i, aggr := range sa.Aggregates {
   192  		op := aggr.Opcode
   193  		if aggr.OrigOpcode != AggregateUnassigned {
   194  			op = aggr.OrigOpcode
   195  		}
   196  		value, err := createEmptyValueFor(op)
   197  		if err != nil {
   198  			return nil, err
   199  		}
   200  		out[i] = value
   201  	}
   202  	return out, nil
   203  }
   204  
   205  func createEmptyValueFor(opcode AggregateOpcode) (sqltypes.Value, error) {
   206  	switch opcode {
   207  	case
   208  		AggregateCountDistinct,
   209  		AggregateCount,
   210  		AggregateCountStar:
   211  		return countZero, nil
   212  	case
   213  		AggregateSumDistinct,
   214  		AggregateSum,
   215  		AggregateMin,
   216  		AggregateMax:
   217  		return sqltypes.NULL, nil
   218  
   219  	}
   220  	return sqltypes.NULL, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "unknown aggregation %v", opcode)
   221  }
   222  
   223  // Inputs implements the Primitive interface
   224  func (sa *ScalarAggregate) Inputs() []Primitive {
   225  	return []Primitive{sa.Input}
   226  }
   227  
   228  // description implements the Primitive interface
   229  func (sa *ScalarAggregate) description() PrimitiveDescription {
   230  	aggregates := GenericJoin(sa.Aggregates, aggregateParamsToString)
   231  	other := map[string]any{
   232  		"Aggregates": aggregates,
   233  	}
   234  	if sa.TruncateColumnCount > 0 {
   235  		other["ResultColumns"] = sa.TruncateColumnCount
   236  	}
   237  	return PrimitiveDescription{
   238  		OperatorType: "Aggregate",
   239  		Variant:      "Scalar",
   240  		Other:        other,
   241  	}
   242  
   243  }