vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/scalar_aggregation.go (about) 1 /* 2 Copyright 2022 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package engine 18 19 import ( 20 "context" 21 "sync" 22 23 "vitess.io/vitess/go/mysql/collations" 24 "vitess.io/vitess/go/sqltypes" 25 querypb "vitess.io/vitess/go/vt/proto/query" 26 "vitess.io/vitess/go/vt/proto/vtrpc" 27 "vitess.io/vitess/go/vt/vterrors" 28 ) 29 30 var _ Primitive = (*ScalarAggregate)(nil) 31 32 // ScalarAggregate is a primitive used to do aggregations without grouping keys 33 type ScalarAggregate struct { 34 // PreProcess is true if one of the aggregates needs preprocessing. 35 PreProcess bool `json:",omitempty"` 36 37 AggrOnEngine bool 38 39 // Aggregates specifies the aggregation parameters for each 40 // aggregation function: function opcode and input column number. 41 Aggregates []*AggregateParams 42 43 // TruncateColumnCount specifies the number of columns to return 44 // in the final result. Rest of the columns are truncated 45 // from the result received. If 0, no truncation happens. 46 TruncateColumnCount int `json:",omitempty"` 47 48 // Collations stores the collation ID per column offset. 49 // It is used for grouping keys and distinct aggregate functions 50 Collations map[int]collations.ID 51 52 // Input is the primitive that will feed into this Primitive. 53 Input Primitive 54 } 55 56 // RouteType implements the Primitive interface 57 func (sa *ScalarAggregate) RouteType() string { 58 return sa.Input.RouteType() 59 } 60 61 // GetKeyspaceName implements the Primitive interface 62 func (sa *ScalarAggregate) GetKeyspaceName() string { 63 return sa.Input.GetKeyspaceName() 64 65 } 66 67 // GetTableName implements the Primitive interface 68 func (sa *ScalarAggregate) GetTableName() string { 69 return sa.Input.GetTableName() 70 } 71 72 // GetFields implements the Primitive interface 73 func (sa *ScalarAggregate) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { 74 qr, err := sa.Input.GetFields(ctx, vcursor, bindVars) 75 if err != nil { 76 return nil, err 77 } 78 qr = &sqltypes.Result{Fields: convertFields(qr.Fields, sa.PreProcess, sa.Aggregates, sa.AggrOnEngine)} 79 return qr.Truncate(sa.TruncateColumnCount), nil 80 } 81 82 // NeedsTransaction implements the Primitive interface 83 func (sa *ScalarAggregate) NeedsTransaction() bool { 84 return sa.Input.NeedsTransaction() 85 } 86 87 // TryExecute implements the Primitive interface 88 func (sa *ScalarAggregate) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { 89 result, err := vcursor.ExecutePrimitive(ctx, sa.Input, bindVars, wantfields) 90 if err != nil { 91 return nil, err 92 } 93 out := &sqltypes.Result{ 94 Fields: convertFields(result.Fields, sa.PreProcess, sa.Aggregates, sa.AggrOnEngine), 95 } 96 97 var resultRow []sqltypes.Value 98 var curDistincts []sqltypes.Value 99 for _, row := range result.Rows { 100 if resultRow == nil { 101 resultRow, curDistincts = convertRow(row, sa.PreProcess, sa.Aggregates, sa.AggrOnEngine) 102 continue 103 } 104 resultRow, curDistincts, err = merge(result.Fields, resultRow, row, curDistincts, sa.Collations, sa.Aggregates) 105 if err != nil { 106 return nil, err 107 } 108 } 109 110 if resultRow == nil { 111 // When doing aggregation without grouping keys, we need to produce a single row containing zero-value for the 112 // different aggregation functions 113 resultRow, err = sa.createEmptyRow() 114 if err != nil { 115 return nil, err 116 } 117 } else { 118 resultRow, err = convertFinal(resultRow, sa.Aggregates) 119 if err != nil { 120 return nil, err 121 } 122 } 123 124 out.Rows = [][]sqltypes.Value{resultRow} 125 return out.Truncate(sa.TruncateColumnCount), nil 126 } 127 128 // TryStreamExecute implements the Primitive interface 129 func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { 130 cb := func(qr *sqltypes.Result) error { 131 return callback(qr.Truncate(sa.TruncateColumnCount)) 132 } 133 var current []sqltypes.Value 134 var curDistincts []sqltypes.Value 135 var fields []*querypb.Field 136 fieldsSent := false 137 var mu sync.Mutex 138 139 err := vcursor.StreamExecutePrimitive(ctx, sa.Input, bindVars, wantfields, func(result *sqltypes.Result) error { 140 // as the underlying primitive call is not sync 141 // and here scalar aggregate is using shared variables we have to sync the callback 142 // for correct aggregation. 143 mu.Lock() 144 defer mu.Unlock() 145 if len(result.Fields) != 0 && !fieldsSent { 146 fields = convertFields(result.Fields, sa.PreProcess, sa.Aggregates, sa.AggrOnEngine) 147 if err := cb(&sqltypes.Result{Fields: fields}); err != nil { 148 return err 149 } 150 fieldsSent = true 151 } 152 153 // this code is very similar to the TryExecute method 154 for _, row := range result.Rows { 155 if current == nil { 156 current, curDistincts = convertRow(row, sa.PreProcess, sa.Aggregates, sa.AggrOnEngine) 157 continue 158 } 159 var err error 160 current, curDistincts, err = merge(fields, current, row, curDistincts, sa.Collations, sa.Aggregates) 161 if err != nil { 162 return err 163 } 164 } 165 return nil 166 }) 167 if err != nil { 168 return err 169 } 170 171 if current == nil { 172 // When doing aggregation without grouping keys, we need to produce a single row containing zero-value for the 173 // different aggregation functions 174 current, err = sa.createEmptyRow() 175 if err != nil { 176 return err 177 } 178 } else { 179 current, err = convertFinal(current, sa.Aggregates) 180 if err != nil { 181 return err 182 } 183 } 184 185 return cb(&sqltypes.Result{Rows: [][]sqltypes.Value{current}}) 186 } 187 188 // creates the empty row for the case when we are missing grouping keys and have empty input table 189 func (sa *ScalarAggregate) createEmptyRow() ([]sqltypes.Value, error) { 190 out := make([]sqltypes.Value, len(sa.Aggregates)) 191 for i, aggr := range sa.Aggregates { 192 op := aggr.Opcode 193 if aggr.OrigOpcode != AggregateUnassigned { 194 op = aggr.OrigOpcode 195 } 196 value, err := createEmptyValueFor(op) 197 if err != nil { 198 return nil, err 199 } 200 out[i] = value 201 } 202 return out, nil 203 } 204 205 func createEmptyValueFor(opcode AggregateOpcode) (sqltypes.Value, error) { 206 switch opcode { 207 case 208 AggregateCountDistinct, 209 AggregateCount, 210 AggregateCountStar: 211 return countZero, nil 212 case 213 AggregateSumDistinct, 214 AggregateSum, 215 AggregateMin, 216 AggregateMax: 217 return sqltypes.NULL, nil 218 219 } 220 return sqltypes.NULL, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "unknown aggregation %v", opcode) 221 } 222 223 // Inputs implements the Primitive interface 224 func (sa *ScalarAggregate) Inputs() []Primitive { 225 return []Primitive{sa.Input} 226 } 227 228 // description implements the Primitive interface 229 func (sa *ScalarAggregate) description() PrimitiveDescription { 230 aggregates := GenericJoin(sa.Aggregates, aggregateParamsToString) 231 other := map[string]any{ 232 "Aggregates": aggregates, 233 } 234 if sa.TruncateColumnCount > 0 { 235 other["ResultColumns"] = sa.TruncateColumnCount 236 } 237 return PrimitiveDescription{ 238 OperatorType: "Aggregate", 239 Variant: "Scalar", 240 Other: other, 241 } 242 243 }