vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/ordered_aggregate.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package engine 18 19 import ( 20 "context" 21 "fmt" 22 "strconv" 23 24 "vitess.io/vitess/go/mysql/collations" 25 26 "vitess.io/vitess/go/vt/sqlparser" 27 28 "google.golang.org/protobuf/proto" 29 30 "vitess.io/vitess/go/sqltypes" 31 "vitess.io/vitess/go/vt/vtgate/evalengine" 32 33 binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" 34 querypb "vitess.io/vitess/go/vt/proto/query" 35 ) 36 37 var _ Primitive = (*OrderedAggregate)(nil) 38 39 // OrderedAggregate is a primitive that expects the underlying primitive 40 // to feed results in an order sorted by the Keys. Rows with duplicate 41 // keys are aggregated using the Aggregate functions. The assumption 42 // is that the underlying primitive is a scatter select with pre-sorted 43 // rows. 44 type OrderedAggregate struct { 45 // PreProcess is true if one of the aggregates needs preprocessing. 46 PreProcess bool `json:",omitempty"` 47 48 // Aggregates specifies the aggregation parameters for each 49 // aggregation function: function opcode and input column number. 50 Aggregates []*AggregateParams 51 52 AggrOnEngine bool 53 54 // GroupByKeys specifies the input values that must be used for 55 // the aggregation key. 56 GroupByKeys []*GroupByParams 57 58 // TruncateColumnCount specifies the number of columns to return 59 // in the final result. Rest of the columns are truncated 60 // from the result received. If 0, no truncation happens. 61 TruncateColumnCount int `json:",omitempty"` 62 63 // Collations stores the collation ID per column offset. 64 // It is used for grouping keys and distinct aggregate functions 65 Collations map[int]collations.ID 66 67 // Input is the primitive that will feed into this Primitive. 68 Input Primitive 69 } 70 71 // GroupByParams specify the grouping key to be used. 72 type GroupByParams struct { 73 KeyCol int 74 WeightStringCol int 75 Expr sqlparser.Expr 76 FromGroupBy bool 77 CollationID collations.ID 78 } 79 80 // String returns a string. Used for plan descriptions 81 func (gbp GroupByParams) String() string { 82 var out string 83 if gbp.WeightStringCol == -1 || gbp.KeyCol == gbp.WeightStringCol { 84 out = strconv.Itoa(gbp.KeyCol) 85 } else { 86 out = fmt.Sprintf("(%d|%d)", gbp.KeyCol, gbp.WeightStringCol) 87 } 88 89 if gbp.CollationID != collations.Unknown { 90 collation := collations.Local().LookupByID(gbp.CollationID) 91 out += " COLLATE " + collation.Name() 92 } 93 94 return out 95 } 96 97 // AggregateParams specify the parameters for each aggregation. 98 // It contains the opcode and input column number. 99 type AggregateParams struct { 100 Opcode AggregateOpcode 101 Col int 102 103 // These are used only for distinct opcodes. 104 KeyCol int 105 WCol int 106 WAssigned bool 107 CollationID collations.ID 108 109 Alias string `json:",omitempty"` 110 Expr sqlparser.Expr 111 Original *sqlparser.AliasedExpr 112 113 // This is based on the function passed in the select expression and 114 // not what we use to aggregate at the engine primitive level. 115 OrigOpcode AggregateOpcode 116 } 117 118 func (ap *AggregateParams) isDistinct() bool { 119 return ap.Opcode == AggregateCountDistinct || ap.Opcode == AggregateSumDistinct 120 } 121 122 func (ap *AggregateParams) preProcess() bool { 123 return ap.Opcode == AggregateCountDistinct || ap.Opcode == AggregateSumDistinct || ap.Opcode == AggregateGtid || ap.Opcode == AggregateCount 124 } 125 126 func (ap *AggregateParams) String() string { 127 keyCol := strconv.Itoa(ap.Col) 128 if ap.WAssigned { 129 keyCol = fmt.Sprintf("%s|%d", keyCol, ap.WCol) 130 } 131 if ap.CollationID != collations.Unknown { 132 collation := collations.Local().LookupByID(ap.CollationID) 133 keyCol += " COLLATE " + collation.Name() 134 } 135 dispOrigOp := "" 136 if ap.OrigOpcode != AggregateUnassigned && ap.OrigOpcode != ap.Opcode { 137 dispOrigOp = "_" + ap.OrigOpcode.String() 138 } 139 if ap.Alias != "" { 140 return fmt.Sprintf("%s%s(%s) AS %s", ap.Opcode.String(), dispOrigOp, keyCol, ap.Alias) 141 } 142 return fmt.Sprintf("%s%s(%s)", ap.Opcode.String(), dispOrigOp, keyCol) 143 } 144 145 // AggregateOpcode is the aggregation Opcode. 146 type AggregateOpcode int 147 148 // These constants list the possible aggregate opcodes. 149 const ( 150 AggregateUnassigned = AggregateOpcode(iota) 151 AggregateCount 152 AggregateSum 153 AggregateMin 154 AggregateMax 155 AggregateCountDistinct 156 AggregateSumDistinct 157 AggregateGtid 158 AggregateRandom 159 AggregateCountStar 160 ) 161 162 var ( 163 // OpcodeType keeps track of the known output types for different aggregate functions 164 OpcodeType = map[AggregateOpcode]querypb.Type{ 165 AggregateCountDistinct: sqltypes.Int64, 166 AggregateCount: sqltypes.Int64, 167 AggregateCountStar: sqltypes.Int64, 168 AggregateSumDistinct: sqltypes.Decimal, 169 AggregateSum: sqltypes.Decimal, 170 AggregateGtid: sqltypes.VarChar, 171 } 172 // Some predefined values 173 countZero = sqltypes.MakeTrusted(sqltypes.Int64, []byte("0")) 174 countOne = sqltypes.MakeTrusted(sqltypes.Int64, []byte("1")) 175 sumZero = sqltypes.MakeTrusted(sqltypes.Decimal, []byte("0")) 176 ) 177 178 // SupportedAggregates maps the list of supported aggregate 179 // functions to their opcodes. 180 var SupportedAggregates = map[string]AggregateOpcode{ 181 "count": AggregateCount, 182 "sum": AggregateSum, 183 "min": AggregateMin, 184 "max": AggregateMax, 185 // These functions don't exist in mysql, but are used 186 // to display the plan. 187 "count_distinct": AggregateCountDistinct, 188 "sum_distinct": AggregateSumDistinct, 189 "vgtid": AggregateGtid, 190 "count_star": AggregateCountStar, 191 "random": AggregateRandom, 192 } 193 194 func (code AggregateOpcode) String() string { 195 for k, v := range SupportedAggregates { 196 if v == code { 197 return k 198 } 199 } 200 return "ERROR" 201 } 202 203 // MarshalJSON serializes the AggregateOpcode as a JSON string. 204 // It's used for testing and diagnostics. 205 func (code AggregateOpcode) MarshalJSON() ([]byte, error) { 206 return ([]byte)(fmt.Sprintf("\"%s\"", code.String())), nil 207 } 208 209 // RouteType returns a description of the query routing type used by the primitive 210 func (oa *OrderedAggregate) RouteType() string { 211 return oa.Input.RouteType() 212 } 213 214 // GetKeyspaceName specifies the Keyspace that this primitive routes to. 215 func (oa *OrderedAggregate) GetKeyspaceName() string { 216 return oa.Input.GetKeyspaceName() 217 } 218 219 // GetTableName specifies the table that this primitive routes to. 220 func (oa *OrderedAggregate) GetTableName() string { 221 return oa.Input.GetTableName() 222 } 223 224 // SetTruncateColumnCount sets the truncate column count. 225 func (oa *OrderedAggregate) SetTruncateColumnCount(count int) { 226 oa.TruncateColumnCount = count 227 } 228 229 // TryExecute is a Primitive function. 230 func (oa *OrderedAggregate) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { 231 qr, err := oa.execute(ctx, vcursor, bindVars, wantfields) 232 if err != nil { 233 return nil, err 234 } 235 return qr.Truncate(oa.TruncateColumnCount), nil 236 } 237 238 func (oa *OrderedAggregate) execute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { 239 result, err := vcursor.ExecutePrimitive(ctx, oa.Input, bindVars, wantfields) 240 if err != nil { 241 return nil, err 242 } 243 out := &sqltypes.Result{ 244 Fields: convertFields(result.Fields, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine), 245 Rows: make([][]sqltypes.Value, 0, len(result.Rows)), 246 } 247 // This code is similar to the one in StreamExecute. 248 var current []sqltypes.Value 249 var curDistincts []sqltypes.Value 250 for _, row := range result.Rows { 251 if current == nil { 252 current, curDistincts = convertRow(row, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine) 253 continue 254 } 255 equal, err := oa.keysEqual(current, row, oa.Collations) 256 if err != nil { 257 return nil, err 258 } 259 260 if equal { 261 current, curDistincts, err = merge(result.Fields, current, row, curDistincts, oa.Collations, oa.Aggregates) 262 if err != nil { 263 return nil, err 264 } 265 continue 266 } 267 out.Rows = append(out.Rows, current) 268 current, curDistincts = convertRow(row, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine) 269 } 270 271 if current != nil { 272 final, err := convertFinal(current, oa.Aggregates) 273 if err != nil { 274 return nil, err 275 } 276 out.Rows = append(out.Rows, final) 277 } 278 return out, nil 279 } 280 281 // TryStreamExecute is a Primitive function. 282 func (oa *OrderedAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { 283 var current []sqltypes.Value 284 var curDistincts []sqltypes.Value 285 var fields []*querypb.Field 286 287 cb := func(qr *sqltypes.Result) error { 288 return callback(qr.Truncate(oa.TruncateColumnCount)) 289 } 290 291 err := vcursor.StreamExecutePrimitive(ctx, oa.Input, bindVars, wantfields, func(qr *sqltypes.Result) error { 292 if len(qr.Fields) != 0 { 293 fields = convertFields(qr.Fields, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine) 294 if err := cb(&sqltypes.Result{Fields: fields}); err != nil { 295 return err 296 } 297 } 298 // This code is similar to the one in Execute. 299 for _, row := range qr.Rows { 300 if current == nil { 301 current, curDistincts = convertRow(row, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine) 302 continue 303 } 304 305 equal, err := oa.keysEqual(current, row, oa.Collations) 306 if err != nil { 307 return err 308 } 309 310 if equal { 311 current, curDistincts, err = merge(fields, current, row, curDistincts, oa.Collations, oa.Aggregates) 312 if err != nil { 313 return err 314 } 315 continue 316 } 317 if err := cb(&sqltypes.Result{Rows: [][]sqltypes.Value{current}}); err != nil { 318 return err 319 } 320 current, curDistincts = convertRow(row, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine) 321 } 322 return nil 323 }) 324 if err != nil { 325 return err 326 } 327 328 if current != nil { 329 if err := cb(&sqltypes.Result{Rows: [][]sqltypes.Value{current}}); err != nil { 330 return err 331 } 332 } 333 return nil 334 } 335 336 func convertFields(fields []*querypb.Field, preProcess bool, aggrs []*AggregateParams, aggrOnEngine bool) []*querypb.Field { 337 if !preProcess { 338 return fields 339 } 340 for _, aggr := range aggrs { 341 if !aggr.preProcess() && !aggrOnEngine { 342 continue 343 } 344 fields[aggr.Col] = &querypb.Field{ 345 Name: aggr.Alias, 346 Type: OpcodeType[aggr.Opcode], 347 } 348 if aggr.isDistinct() { 349 aggr.KeyCol = aggr.Col 350 } 351 } 352 return fields 353 } 354 355 func convertRow(row []sqltypes.Value, preProcess bool, aggregates []*AggregateParams, aggrOnEngine bool) (newRow []sqltypes.Value, curDistincts []sqltypes.Value) { 356 if !preProcess { 357 return row, nil 358 } 359 newRow = append(newRow, row...) 360 curDistincts = make([]sqltypes.Value, len(aggregates)) 361 for index, aggr := range aggregates { 362 switch aggr.Opcode { 363 case AggregateCountStar: 364 newRow[aggr.Col] = countOne 365 case AggregateCount: 366 val := countOne 367 if row[aggr.Col].IsNull() { 368 val = countZero 369 } 370 newRow[aggr.Col] = val 371 case AggregateCountDistinct: 372 curDistincts[index] = findComparableCurrentDistinct(row, aggr) 373 // Type is int64. Ok to call MakeTrusted. 374 if row[aggr.KeyCol].IsNull() { 375 newRow[aggr.Col] = countZero 376 } else { 377 newRow[aggr.Col] = countOne 378 } 379 case AggregateSum: 380 if !aggrOnEngine { 381 break 382 } 383 if row[aggr.Col].IsNull() { 384 break 385 } 386 var err error 387 newRow[aggr.Col], err = evalengine.Cast(row[aggr.Col], OpcodeType[aggr.Opcode]) 388 if err != nil { 389 newRow[aggr.Col] = sumZero 390 } 391 case AggregateSumDistinct: 392 curDistincts[index] = findComparableCurrentDistinct(row, aggr) 393 var err error 394 newRow[aggr.Col], err = evalengine.Cast(row[aggr.Col], OpcodeType[aggr.Opcode]) 395 if err != nil { 396 newRow[aggr.Col] = sumZero 397 } 398 case AggregateGtid: 399 vgtid := &binlogdatapb.VGtid{} 400 vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{ 401 Keyspace: row[aggr.Col-1].ToString(), 402 Shard: row[aggr.Col+1].ToString(), 403 Gtid: row[aggr.Col].ToString(), 404 }) 405 data, _ := proto.Marshal(vgtid) 406 val, _ := sqltypes.NewValue(sqltypes.VarBinary, data) 407 newRow[aggr.Col] = val 408 } 409 } 410 return newRow, curDistincts 411 } 412 413 func findComparableCurrentDistinct(row []sqltypes.Value, aggr *AggregateParams) sqltypes.Value { 414 curDistinct := row[aggr.KeyCol] 415 if aggr.WAssigned && !curDistinct.IsComparable() { 416 aggr.KeyCol = aggr.WCol 417 curDistinct = row[aggr.KeyCol] 418 } 419 return curDistinct 420 } 421 422 // GetFields is a Primitive function. 423 func (oa *OrderedAggregate) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { 424 qr, err := oa.Input.GetFields(ctx, vcursor, bindVars) 425 if err != nil { 426 return nil, err 427 } 428 qr = &sqltypes.Result{Fields: convertFields(qr.Fields, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine)} 429 return qr.Truncate(oa.TruncateColumnCount), nil 430 } 431 432 // Inputs returns the Primitive input for this aggregation 433 func (oa *OrderedAggregate) Inputs() []Primitive { 434 return []Primitive{oa.Input} 435 } 436 437 // NeedsTransaction implements the Primitive interface 438 func (oa *OrderedAggregate) NeedsTransaction() bool { 439 return oa.Input.NeedsTransaction() 440 } 441 442 func (oa *OrderedAggregate) keysEqual(row1, row2 []sqltypes.Value, colls map[int]collations.ID) (bool, error) { 443 for _, key := range oa.GroupByKeys { 444 cmp, err := evalengine.NullsafeCompare(row1[key.KeyCol], row2[key.KeyCol], colls[key.KeyCol]) 445 if err != nil { 446 _, isComparisonErr := err.(evalengine.UnsupportedComparisonError) 447 _, isCollationErr := err.(evalengine.UnsupportedCollationError) 448 if !isComparisonErr && !isCollationErr || key.WeightStringCol == -1 { 449 return false, err 450 } 451 key.KeyCol = key.WeightStringCol 452 cmp, err = evalengine.NullsafeCompare(row1[key.WeightStringCol], row2[key.WeightStringCol], colls[key.KeyCol]) 453 if err != nil { 454 return false, err 455 } 456 } 457 if cmp != 0 { 458 return false, nil 459 } 460 } 461 return true, nil 462 } 463 464 func merge( 465 fields []*querypb.Field, 466 row1, row2 []sqltypes.Value, 467 curDistincts []sqltypes.Value, 468 colls map[int]collations.ID, 469 aggregates []*AggregateParams, 470 ) ([]sqltypes.Value, []sqltypes.Value, error) { 471 result := sqltypes.CopyRow(row1) 472 for index, aggr := range aggregates { 473 if aggr.isDistinct() { 474 if row2[aggr.KeyCol].IsNull() { 475 continue 476 } 477 cmp, err := evalengine.NullsafeCompare(curDistincts[index], row2[aggr.KeyCol], colls[aggr.KeyCol]) 478 if err != nil { 479 return nil, nil, err 480 } 481 if cmp == 0 { 482 continue 483 } 484 curDistincts[index] = findComparableCurrentDistinct(row2, aggr) 485 } 486 var err error 487 switch aggr.Opcode { 488 case AggregateCountStar: 489 value := row1[aggr.Col] 490 result[aggr.Col], err = evalengine.NullSafeAdd(value, countOne, fields[aggr.Col].Type) 491 case AggregateCount: 492 val := countOne 493 if row2[aggr.Col].IsNull() { 494 val = countZero 495 } 496 result[aggr.Col], err = evalengine.NullSafeAdd(row1[aggr.Col], val, fields[aggr.Col].Type) 497 case AggregateSum: 498 value := row1[aggr.Col] 499 v2 := row2[aggr.Col] 500 if value.IsNull() && v2.IsNull() { 501 result[aggr.Col] = sqltypes.NULL 502 break 503 } 504 result[aggr.Col], err = evalengine.NullSafeAdd(value, v2, fields[aggr.Col].Type) 505 case AggregateMin: 506 result[aggr.Col], err = evalengine.Min(row1[aggr.Col], row2[aggr.Col], colls[aggr.Col]) 507 case AggregateMax: 508 result[aggr.Col], err = evalengine.Max(row1[aggr.Col], row2[aggr.Col], colls[aggr.Col]) 509 case AggregateCountDistinct: 510 result[aggr.Col], err = evalengine.NullSafeAdd(row1[aggr.Col], countOne, OpcodeType[aggr.Opcode]) 511 case AggregateSumDistinct: 512 result[aggr.Col], err = evalengine.NullSafeAdd(row1[aggr.Col], row2[aggr.Col], OpcodeType[aggr.Opcode]) 513 case AggregateGtid: 514 vgtid := &binlogdatapb.VGtid{} 515 rowBytes, err := row1[aggr.Col].ToBytes() 516 if err != nil { 517 return nil, nil, err 518 } 519 err = proto.Unmarshal(rowBytes, vgtid) 520 if err != nil { 521 return nil, nil, err 522 } 523 vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{ 524 Keyspace: row2[aggr.Col-1].ToString(), 525 Shard: row2[aggr.Col+1].ToString(), 526 Gtid: row2[aggr.Col].ToString(), 527 }) 528 data, _ := proto.Marshal(vgtid) 529 val, _ := sqltypes.NewValue(sqltypes.VarBinary, data) 530 result[aggr.Col] = val 531 case AggregateRandom: 532 // we just grab the first value per grouping. no need to do anything more complicated here 533 default: 534 return nil, nil, fmt.Errorf("BUG: Unexpected opcode: %v", aggr.Opcode) 535 } 536 if err != nil { 537 return nil, nil, err 538 } 539 } 540 return result, curDistincts, nil 541 } 542 543 func aggregateParamsToString(in any) string { 544 return in.(*AggregateParams).String() 545 } 546 547 func groupByParamsToString(i any) string { 548 return i.(*GroupByParams).String() 549 } 550 551 func (oa *OrderedAggregate) description() PrimitiveDescription { 552 aggregates := GenericJoin(oa.Aggregates, aggregateParamsToString) 553 groupBy := GenericJoin(oa.GroupByKeys, groupByParamsToString) 554 other := map[string]any{ 555 "Aggregates": aggregates, 556 "GroupBy": groupBy, 557 } 558 if oa.TruncateColumnCount > 0 { 559 other["ResultColumns"] = oa.TruncateColumnCount 560 } 561 return PrimitiveDescription{ 562 OperatorType: "Aggregate", 563 Variant: "Ordered", 564 Other: other, 565 } 566 } 567 568 func convertFinal(current []sqltypes.Value, aggregates []*AggregateParams) ([]sqltypes.Value, error) { 569 result := sqltypes.CopyRow(current) 570 for _, aggr := range aggregates { 571 switch aggr.Opcode { 572 case AggregateGtid: 573 vgtid := &binlogdatapb.VGtid{} 574 currentBytes, err := current[aggr.Col].ToBytes() 575 if err != nil { 576 return nil, err 577 } 578 err = proto.Unmarshal(currentBytes, vgtid) 579 if err != nil { 580 return nil, err 581 } 582 result[aggr.Col] = sqltypes.NewVarChar(vgtid.String()) 583 } 584 } 585 return result, nil 586 }