vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/ordered_aggregate_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package engine 18 19 import ( 20 "context" 21 "errors" 22 "fmt" 23 "testing" 24 25 "github.com/stretchr/testify/assert" 26 "github.com/stretchr/testify/require" 27 28 "vitess.io/vitess/go/mysql/collations" 29 "vitess.io/vitess/go/sqltypes" 30 "vitess.io/vitess/go/test/utils" 31 32 binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" 33 querypb "vitess.io/vitess/go/vt/proto/query" 34 "vitess.io/vitess/go/vt/servenv" 35 ) 36 37 var collationEnv *collations.Environment 38 39 func init() { 40 // We require MySQL 8.0 collations for the comparisons in the tests 41 mySQLVersion := "8.0.0" 42 servenv.SetMySQLServerVersionForTest(mySQLVersion) 43 collationEnv = collations.NewEnvironment(mySQLVersion) 44 } 45 46 func TestOrderedAggregateExecute(t *testing.T) { 47 assert := assert.New(t) 48 fields := sqltypes.MakeTestFields( 49 "col|count(*)", 50 "varbinary|decimal", 51 ) 52 fp := &fakePrimitive{ 53 results: []*sqltypes.Result{sqltypes.MakeTestResult( 54 fields, 55 "a|1", 56 "a|1", 57 "b|2", 58 "c|3", 59 "c|4", 60 )}, 61 } 62 63 oa := &OrderedAggregate{ 64 Aggregates: []*AggregateParams{{ 65 Opcode: AggregateSum, 66 Col: 1, 67 }}, 68 GroupByKeys: []*GroupByParams{{KeyCol: 0}}, 69 Input: fp, 70 } 71 72 result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) 73 assert.NoError(err) 74 75 wantResult := sqltypes.MakeTestResult( 76 fields, 77 "a|2", 78 "b|2", 79 "c|7", 80 ) 81 assert.Equal(wantResult, result) 82 } 83 84 func TestOrderedAggregateExecuteTruncate(t *testing.T) { 85 assert := assert.New(t) 86 fp := &fakePrimitive{ 87 results: []*sqltypes.Result{sqltypes.MakeTestResult( 88 sqltypes.MakeTestFields( 89 "col|count(*)|weight_string(col)", 90 "varchar|decimal|varbinary", 91 ), 92 "a|1|A", 93 "A|1|A", 94 "b|2|B", 95 "C|3|C", 96 "c|4|C", 97 )}, 98 } 99 100 oa := &OrderedAggregate{ 101 Aggregates: []*AggregateParams{{ 102 Opcode: AggregateSum, 103 Col: 1, 104 }}, 105 GroupByKeys: []*GroupByParams{{KeyCol: 2}}, 106 TruncateColumnCount: 2, 107 Input: fp, 108 } 109 110 result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) 111 assert.NoError(err) 112 113 wantResult := sqltypes.MakeTestResult( 114 sqltypes.MakeTestFields( 115 "col|count(*)", 116 "varchar|decimal", 117 ), 118 "a|2", 119 "b|2", 120 "C|7", 121 ) 122 assert.Equal(wantResult, result) 123 } 124 125 func TestOrderedAggregateStreamExecute(t *testing.T) { 126 assert := assert.New(t) 127 fields := sqltypes.MakeTestFields( 128 "col|count(*)", 129 "varbinary|decimal", 130 ) 131 fp := &fakePrimitive{ 132 results: []*sqltypes.Result{sqltypes.MakeTestResult( 133 fields, 134 "a|1", 135 "a|1", 136 "b|2", 137 "c|3", 138 "c|4", 139 )}, 140 } 141 142 oa := &OrderedAggregate{ 143 Aggregates: []*AggregateParams{{ 144 Opcode: AggregateSum, 145 Col: 1, 146 }}, 147 GroupByKeys: []*GroupByParams{{KeyCol: 0}}, 148 Input: fp, 149 } 150 151 var results []*sqltypes.Result 152 err := oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error { 153 results = append(results, qr) 154 return nil 155 }) 156 assert.NoError(err) 157 158 wantResults := sqltypes.MakeTestStreamingResults( 159 fields, 160 "a|2", 161 "---", 162 "b|2", 163 "---", 164 "c|7", 165 ) 166 assert.Equal(wantResults, results) 167 } 168 169 func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) { 170 assert := assert.New(t) 171 fp := &fakePrimitive{ 172 results: []*sqltypes.Result{sqltypes.MakeTestResult( 173 sqltypes.MakeTestFields( 174 "col|count(*)|weight_string(col)", 175 "varchar|decimal|varbinary", 176 ), 177 "a|1|A", 178 "A|1|A", 179 "b|2|B", 180 "C|3|C", 181 "c|4|C", 182 )}, 183 } 184 185 oa := &OrderedAggregate{ 186 Aggregates: []*AggregateParams{{ 187 Opcode: AggregateSum, 188 Col: 1, 189 }}, 190 GroupByKeys: []*GroupByParams{{KeyCol: 2}}, 191 TruncateColumnCount: 2, 192 Input: fp, 193 } 194 195 var results []*sqltypes.Result 196 err := oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error { 197 results = append(results, qr) 198 return nil 199 }) 200 assert.NoError(err) 201 202 wantResults := sqltypes.MakeTestStreamingResults( 203 sqltypes.MakeTestFields( 204 "col|count(*)", 205 "varchar|decimal", 206 ), 207 "a|2", 208 "---", 209 "b|2", 210 "---", 211 "C|7", 212 ) 213 assert.Equal(wantResults, results) 214 } 215 216 func TestOrderedAggregateGetFields(t *testing.T) { 217 assert := assert.New(t) 218 input := sqltypes.MakeTestResult( 219 sqltypes.MakeTestFields( 220 "col|count(*)", 221 "varbinary|decimal", 222 ), 223 ) 224 fp := &fakePrimitive{results: []*sqltypes.Result{input}} 225 226 oa := &OrderedAggregate{Input: fp} 227 228 got, err := oa.GetFields(context.Background(), nil, nil) 229 assert.NoError(err) 230 assert.Equal(got, input) 231 } 232 233 func TestOrderedAggregateGetFieldsTruncate(t *testing.T) { 234 assert := assert.New(t) 235 result := sqltypes.MakeTestResult( 236 sqltypes.MakeTestFields( 237 "col|count(*)|weight_string(col)", 238 "varchar|decimal|varbinary", 239 ), 240 ) 241 fp := &fakePrimitive{results: []*sqltypes.Result{result}} 242 243 oa := &OrderedAggregate{ 244 TruncateColumnCount: 2, 245 Input: fp, 246 } 247 248 got, err := oa.GetFields(context.Background(), nil, nil) 249 assert.NoError(err) 250 wantResult := sqltypes.MakeTestResult( 251 sqltypes.MakeTestFields( 252 "col|count(*)", 253 "varchar|decimal", 254 ), 255 ) 256 assert.Equal(wantResult, got) 257 } 258 259 func TestOrderedAggregateInputFail(t *testing.T) { 260 fp := &fakePrimitive{sendErr: errors.New("input fail")} 261 262 oa := &OrderedAggregate{Input: fp} 263 264 want := "input fail" 265 if _, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false); err == nil || err.Error() != want { 266 t.Errorf("oa.Execute(): %v, want %s", err, want) 267 } 268 269 fp.rewind() 270 if err := oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, false, func(_ *sqltypes.Result) error { return nil }); err == nil || err.Error() != want { 271 t.Errorf("oa.StreamExecute(): %v, want %s", err, want) 272 } 273 274 fp.rewind() 275 if _, err := oa.GetFields(context.Background(), nil, nil); err == nil || err.Error() != want { 276 t.Errorf("oa.GetFields(): %v, want %s", err, want) 277 } 278 } 279 280 func TestOrderedAggregateExecuteCountDistinct(t *testing.T) { 281 assert := assert.New(t) 282 fp := &fakePrimitive{ 283 results: []*sqltypes.Result{sqltypes.MakeTestResult( 284 sqltypes.MakeTestFields( 285 "col1|col2|count(*)", 286 "varbinary|decimal|int64", 287 ), 288 // Two identical values 289 "a|1|1", 290 "a|1|2", 291 // Single value 292 "b|1|1", 293 // Two different values 294 "c|3|1", 295 "c|4|1", 296 // Single null 297 "d|null|1", 298 // Start with null 299 "e|null|1", 300 "e|1|1", 301 // Null comes after first 302 "f|1|1", 303 "f|null|1", 304 // Identical to non-identical transition 305 "g|1|1", 306 "g|1|1", 307 "g|2|1", 308 "g|3|1", 309 // Non-identical to identical transition 310 "h|1|1", 311 "h|2|1", 312 "h|2|1", 313 "h|3|1", 314 // Key transition, should still count 3 315 "i|3|1", 316 "i|4|1", 317 )}, 318 } 319 320 oa := &OrderedAggregate{ 321 PreProcess: true, 322 Aggregates: []*AggregateParams{{ 323 Opcode: AggregateCountDistinct, 324 Col: 1, 325 Alias: "count(distinct col2)", 326 }, { 327 // Also add a count(*) 328 Opcode: AggregateSum, 329 Col: 2, 330 }}, 331 GroupByKeys: []*GroupByParams{{KeyCol: 0}}, 332 Input: fp, 333 } 334 335 result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) 336 assert.NoError(err) 337 338 wantResult := sqltypes.MakeTestResult( 339 sqltypes.MakeTestFields( 340 "col1|count(distinct col2)|count(*)", 341 "varbinary|int64|int64", 342 ), 343 "a|1|3", 344 "b|1|1", 345 "c|2|2", 346 "d|0|1", 347 "e|1|2", 348 "f|1|2", 349 "g|3|4", 350 "h|3|4", 351 "i|2|2", 352 ) 353 assert.Equal(wantResult, result) 354 } 355 356 func TestOrderedAggregateStreamCountDistinct(t *testing.T) { 357 assert := assert.New(t) 358 fp := &fakePrimitive{ 359 results: []*sqltypes.Result{sqltypes.MakeTestResult( 360 sqltypes.MakeTestFields( 361 "col1|col2|count(*)", 362 "varbinary|decimal|int64", 363 ), 364 // Two identical values 365 "a|1|1", 366 "a|1|2", 367 // Single value 368 "b|1|1", 369 // Two different values 370 "c|3|1", 371 "c|4|1", 372 // Single null 373 "d|null|1", 374 // Start with null 375 "e|null|1", 376 "e|1|1", 377 // Null comes after first 378 "f|1|1", 379 "f|null|1", 380 // Identical to non-identical transition 381 "g|1|1", 382 "g|1|1", 383 "g|2|1", 384 "g|3|1", 385 // Non-identical to identical transition 386 "h|1|1", 387 "h|2|1", 388 "h|2|1", 389 "h|3|1", 390 // Key transition, should still count 3 391 "i|3|1", 392 "i|4|1", 393 )}, 394 } 395 396 oa := &OrderedAggregate{ 397 PreProcess: true, 398 Aggregates: []*AggregateParams{{ 399 Opcode: AggregateCountDistinct, 400 Col: 1, 401 Alias: "count(distinct col2)", 402 }, { 403 // Also add a count(*) 404 Opcode: AggregateSum, 405 Col: 2, 406 }}, 407 GroupByKeys: []*GroupByParams{{KeyCol: 0}}, 408 Input: fp, 409 } 410 411 var results []*sqltypes.Result 412 err := oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error { 413 results = append(results, qr) 414 return nil 415 }) 416 assert.NoError(err) 417 418 wantResults := sqltypes.MakeTestStreamingResults( 419 sqltypes.MakeTestFields( 420 "col1|count(distinct col2)|count(*)", 421 "varbinary|int64|int64", 422 ), 423 "a|1|3", 424 "-----", 425 "b|1|1", 426 "-----", 427 "c|2|2", 428 "-----", 429 "d|0|1", 430 "-----", 431 "e|1|2", 432 "-----", 433 "f|1|2", 434 "-----", 435 "g|3|4", 436 "-----", 437 "h|3|4", 438 "-----", 439 "i|2|2", 440 ) 441 assert.Equal(wantResults, results) 442 } 443 444 func TestOrderedAggregateSumDistinctGood(t *testing.T) { 445 assert := assert.New(t) 446 fp := &fakePrimitive{ 447 results: []*sqltypes.Result{sqltypes.MakeTestResult( 448 sqltypes.MakeTestFields( 449 "col1|col2|sum(col3)", 450 "varbinary|int64|decimal", 451 ), 452 // Two identical values 453 "a|1|1", 454 "a|1|2", 455 // Single value 456 "b|1|1", 457 // Two different values 458 "c|3|1", 459 "c|4|1", 460 // Single null 461 "d|null|1", 462 "d|1|1", 463 // Start with null 464 "e|null|1", 465 "e|1|1", 466 // Null comes after first 467 "f|1|1", 468 "f|null|1", 469 // Identical to non-identical transition 470 "g|1|1", 471 "g|1|1", 472 "g|2|1", 473 "g|3|1", 474 // Non-identical to identical transition 475 "h|1|1", 476 "h|2|1", 477 "h|2|1", 478 "h|3|1", 479 // Key transition, should still count 3 480 "i|3|1", 481 "i|4|1", 482 )}, 483 } 484 485 oa := &OrderedAggregate{ 486 PreProcess: true, 487 Aggregates: []*AggregateParams{{ 488 Opcode: AggregateSumDistinct, 489 Col: 1, 490 Alias: "sum(distinct col2)", 491 }, { 492 // Also add a count(*) 493 Opcode: AggregateSum, 494 Col: 2, 495 }}, 496 GroupByKeys: []*GroupByParams{{KeyCol: 0}}, 497 Input: fp, 498 } 499 500 result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) 501 assert.NoError(err) 502 503 wantResult := sqltypes.MakeTestResult( 504 sqltypes.MakeTestFields( 505 "col1|sum(distinct col2)|sum(col3)", 506 "varbinary|decimal|decimal", 507 ), 508 "a|1|3", 509 "b|1|1", 510 "c|7|2", 511 "d|1|2", 512 "e|1|2", 513 "f|1|2", 514 "g|6|4", 515 "h|6|4", 516 "i|7|2", 517 ) 518 want := fmt.Sprintf("%v", wantResult.Rows) 519 got := fmt.Sprintf("%v", result.Rows) 520 assert.Equal(want, got) 521 } 522 523 func TestOrderedAggregateSumDistinctTolerateError(t *testing.T) { 524 fp := &fakePrimitive{ 525 results: []*sqltypes.Result{sqltypes.MakeTestResult( 526 sqltypes.MakeTestFields( 527 "col1|col2", 528 "varbinary|varbinary", 529 ), 530 "a|aaa", 531 "a|0", 532 "a|1", 533 )}, 534 } 535 536 oa := &OrderedAggregate{ 537 PreProcess: true, 538 Aggregates: []*AggregateParams{{ 539 Opcode: AggregateSumDistinct, 540 Col: 1, 541 Alias: "sum(distinct col2)", 542 }}, 543 GroupByKeys: []*GroupByParams{{KeyCol: 0}}, 544 Input: fp, 545 } 546 547 result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) 548 assert.NoError(t, err) 549 550 wantResult := sqltypes.MakeTestResult( 551 sqltypes.MakeTestFields( 552 "col1|sum(distinct col2)", 553 "varbinary|decimal", 554 ), 555 "a|1", 556 ) 557 utils.MustMatch(t, wantResult, result, "") 558 } 559 560 func TestOrderedAggregateKeysFail(t *testing.T) { 561 fields := sqltypes.MakeTestFields( 562 "col|count(*)", 563 "varchar|decimal", 564 ) 565 fp := &fakePrimitive{ 566 results: []*sqltypes.Result{sqltypes.MakeTestResult( 567 fields, 568 "a|1", 569 "a|1", 570 )}, 571 } 572 573 oa := &OrderedAggregate{ 574 Aggregates: []*AggregateParams{{ 575 Opcode: AggregateSum, 576 Col: 1, 577 }}, 578 GroupByKeys: []*GroupByParams{{KeyCol: 0}}, 579 Input: fp, 580 } 581 582 want := "cannot compare strings, collation is unknown or unsupported (collation ID: 0)" 583 if _, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false); err == nil || err.Error() != want { 584 t.Errorf("oa.Execute(): %v, want %s", err, want) 585 } 586 587 fp.rewind() 588 if err := oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, false, func(_ *sqltypes.Result) error { return nil }); err == nil || err.Error() != want { 589 t.Errorf("oa.StreamExecute(): %v, want %s", err, want) 590 } 591 } 592 593 func TestOrderedAggregateMergeFail(t *testing.T) { 594 fields := sqltypes.MakeTestFields( 595 "col|count(*)", 596 "varbinary|decimal", 597 ) 598 fp := &fakePrimitive{ 599 results: []*sqltypes.Result{sqltypes.MakeTestResult( 600 fields, 601 "a|1", 602 "a|0", 603 )}, 604 } 605 606 oa := &OrderedAggregate{ 607 Aggregates: []*AggregateParams{{ 608 Opcode: AggregateSum, 609 Col: 1, 610 }}, 611 GroupByKeys: []*GroupByParams{{KeyCol: 0}}, 612 Input: fp, 613 } 614 615 result := &sqltypes.Result{ 616 Fields: []*querypb.Field{ 617 { 618 Name: "col", 619 Type: querypb.Type_VARBINARY, 620 }, 621 { 622 Name: "count(*)", 623 Type: querypb.Type_DECIMAL, 624 }, 625 }, 626 Rows: [][]sqltypes.Value{ 627 { 628 sqltypes.MakeTrusted(querypb.Type_VARBINARY, []byte("a")), 629 sqltypes.MakeTrusted(querypb.Type_DECIMAL, []byte("1")), 630 }, 631 }, 632 } 633 634 res, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) 635 require.NoError(t, err) 636 637 utils.MustMatch(t, result, res, "Found mismatched values") 638 639 fp.rewind() 640 err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(_ *sqltypes.Result) error { return nil }) 641 require.NoError(t, err) 642 } 643 644 func TestMerge(t *testing.T) { 645 assert := assert.New(t) 646 oa := &OrderedAggregate{ 647 Aggregates: []*AggregateParams{{ 648 Opcode: AggregateSum, 649 Col: 1, 650 }, { 651 Opcode: AggregateSum, 652 Col: 2, 653 }, { 654 Opcode: AggregateMin, 655 Col: 3, 656 }, { 657 Opcode: AggregateMax, 658 Col: 4, 659 }}, 660 } 661 fields := sqltypes.MakeTestFields( 662 "a|b|c|d|e", 663 "int64|int64|decimal|in32|varbinary", 664 ) 665 r := sqltypes.MakeTestResult(fields, 666 "1|2|3.2|3|ab", 667 "1|3|2.8|2|bc", 668 ) 669 670 merged, _, err := merge(fields, r.Rows[0], r.Rows[1], nil, nil, oa.Aggregates) 671 assert.NoError(err) 672 want := sqltypes.MakeTestResult(fields, "1|5|6.0|2|bc").Rows[0] 673 assert.Equal(want, merged) 674 675 // swap and retry 676 merged, _, err = merge(fields, r.Rows[1], r.Rows[0], nil, nil, oa.Aggregates) 677 assert.NoError(err) 678 assert.Equal(want, merged) 679 } 680 681 func TestOrderedAggregateExecuteGtid(t *testing.T) { 682 vgtid := binlogdatapb.VGtid{} 683 vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{ 684 Keyspace: "ks", 685 Shard: "-80", 686 Gtid: "a", 687 }) 688 vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{ 689 Keyspace: "ks", 690 Shard: "80-", 691 Gtid: "b", 692 }) 693 694 fp := &fakePrimitive{ 695 results: []*sqltypes.Result{sqltypes.MakeTestResult( 696 sqltypes.MakeTestFields( 697 "keyspace|gtid|shard", 698 "varchar|varchar|varchar", 699 ), 700 "ks|a|-40", 701 "ks|b|40-80", 702 "ks|c|80-c0", 703 "ks|d|c0-", 704 )}, 705 } 706 707 oa := &OrderedAggregate{ 708 PreProcess: true, 709 Aggregates: []*AggregateParams{{ 710 Opcode: AggregateGtid, 711 Col: 1, 712 Alias: "vgtid", 713 }}, 714 TruncateColumnCount: 2, 715 Input: fp, 716 } 717 718 result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) 719 require.NoError(t, err) 720 721 wantResult := sqltypes.MakeTestResult( 722 sqltypes.MakeTestFields( 723 "keyspace|vgtid", 724 "varchar|varchar", 725 ), 726 `ks|shard_gtids:{keyspace:"ks" shard:"-40" gtid:"a"} shard_gtids:{keyspace:"ks" shard:"40-80" gtid:"b"} shard_gtids:{keyspace:"ks" shard:"80-c0" gtid:"c"} shard_gtids:{keyspace:"ks" shard:"c0-" gtid:"d"}`, 727 ) 728 assert.Equal(t, wantResult, result) 729 } 730 731 func TestCountDistinctOnVarchar(t *testing.T) { 732 fields := sqltypes.MakeTestFields( 733 "c1|c2|weight_string(c2)", 734 "int64|varchar|varbinary", 735 ) 736 fp := &fakePrimitive{ 737 results: []*sqltypes.Result{sqltypes.MakeTestResult( 738 fields, 739 "10|a|0x41", 740 "10|a|0x41", 741 "10|b|0x42", 742 "20|b|0x42", 743 )}, 744 } 745 746 oa := &OrderedAggregate{ 747 PreProcess: true, 748 Aggregates: []*AggregateParams{{ 749 Opcode: AggregateCountDistinct, 750 Col: 1, 751 WCol: 2, 752 WAssigned: true, 753 Alias: "count(distinct c2)", 754 }}, 755 GroupByKeys: []*GroupByParams{{KeyCol: 0}}, 756 Input: fp, 757 TruncateColumnCount: 2, 758 } 759 760 want := sqltypes.MakeTestResult( 761 sqltypes.MakeTestFields( 762 "c1|count(distinct c2)", 763 "int64|int64", 764 ), 765 `10|2`, 766 `20|1`, 767 ) 768 769 qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) 770 require.NoError(t, err) 771 assert.Equal(t, want, qr) 772 773 fp.rewind() 774 results := &sqltypes.Result{} 775 err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error { 776 if qr.Fields != nil { 777 results.Fields = qr.Fields 778 } 779 results.Rows = append(results.Rows, qr.Rows...) 780 return nil 781 }) 782 require.NoError(t, err) 783 assert.Equal(t, want, results) 784 } 785 786 func TestCountDistinctOnVarcharWithNulls(t *testing.T) { 787 fields := sqltypes.MakeTestFields( 788 "c1|c2|weight_string(c2)", 789 "int64|varchar|varbinary", 790 ) 791 fp := &fakePrimitive{ 792 results: []*sqltypes.Result{sqltypes.MakeTestResult( 793 fields, 794 "null|null|null", 795 "null|a|0x41", 796 "null|b|0x42", 797 "10|null|null", 798 "10|null|null", 799 "10|a|0x41", 800 "10|a|0x41", 801 "10|b|0x42", 802 "20|null|null", 803 "20|b|0x42", 804 "30|null|null", 805 "30|null|null", 806 "30|null|null", 807 "30|null|null", 808 )}, 809 } 810 811 oa := &OrderedAggregate{ 812 PreProcess: true, 813 Aggregates: []*AggregateParams{{ 814 Opcode: AggregateCountDistinct, 815 Col: 1, 816 WCol: 2, 817 WAssigned: true, 818 Alias: "count(distinct c2)", 819 }}, 820 GroupByKeys: []*GroupByParams{{KeyCol: 0}}, 821 Input: fp, 822 TruncateColumnCount: 2, 823 } 824 825 want := sqltypes.MakeTestResult( 826 sqltypes.MakeTestFields( 827 "c1|count(distinct c2)", 828 "int64|int64", 829 ), 830 `null|2`, 831 `10|2`, 832 `20|1`, 833 `30|0`, 834 ) 835 836 qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) 837 require.NoError(t, err) 838 assert.Equal(t, want, qr) 839 840 fp.rewind() 841 results := &sqltypes.Result{} 842 err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error { 843 if qr.Fields != nil { 844 results.Fields = qr.Fields 845 } 846 results.Rows = append(results.Rows, qr.Rows...) 847 return nil 848 }) 849 require.NoError(t, err) 850 assert.Equal(t, want, results) 851 } 852 853 func TestSumDistinctOnVarcharWithNulls(t *testing.T) { 854 fields := sqltypes.MakeTestFields( 855 "c1|c2|weight_string(c2)", 856 "int64|varchar|varbinary", 857 ) 858 fp := &fakePrimitive{ 859 results: []*sqltypes.Result{sqltypes.MakeTestResult( 860 fields, 861 "null|null|null", 862 "null|a|0x41", 863 "null|b|0x42", 864 "10|null|null", 865 "10|null|null", 866 "10|a|0x41", 867 "10|a|0x41", 868 "10|b|0x42", 869 "20|null|null", 870 "20|b|0x42", 871 "30|null|null", 872 "30|null|null", 873 "30|null|null", 874 "30|null|null", 875 )}, 876 } 877 878 oa := &OrderedAggregate{ 879 PreProcess: true, 880 Aggregates: []*AggregateParams{{ 881 Opcode: AggregateSumDistinct, 882 Col: 1, 883 WCol: 2, 884 WAssigned: true, 885 Alias: "sum(distinct c2)", 886 }}, 887 GroupByKeys: []*GroupByParams{{KeyCol: 0}}, 888 Input: fp, 889 TruncateColumnCount: 2, 890 } 891 892 want := sqltypes.MakeTestResult( 893 sqltypes.MakeTestFields( 894 "c1|sum(distinct c2)", 895 "int64|decimal", 896 ), 897 `null|0`, 898 `10|0`, 899 `20|0`, 900 `30|null`, 901 ) 902 903 qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) 904 require.NoError(t, err) 905 assert.Equal(t, want, qr) 906 907 fp.rewind() 908 results := &sqltypes.Result{} 909 err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error { 910 if qr.Fields != nil { 911 results.Fields = qr.Fields 912 } 913 results.Rows = append(results.Rows, qr.Rows...) 914 return nil 915 }) 916 require.NoError(t, err) 917 assert.Equal(t, want, results) 918 } 919 920 func TestMultiDistinct(t *testing.T) { 921 fields := sqltypes.MakeTestFields( 922 "c1|c2|c3", 923 "int64|int64|int64", 924 ) 925 fp := &fakePrimitive{ 926 results: []*sqltypes.Result{sqltypes.MakeTestResult( 927 fields, 928 "null|null|null", 929 "null|1|2", 930 "null|2|2", 931 "10|null|null", 932 "10|2|null", 933 "10|2|1", 934 "10|2|3", 935 "10|3|3", 936 "20|null|null", 937 "20|null|null", 938 "30|1|1", 939 "30|1|2", 940 "30|1|3", 941 "40|1|1", 942 "40|2|1", 943 "40|3|1", 944 )}, 945 } 946 947 oa := &OrderedAggregate{ 948 PreProcess: true, 949 Aggregates: []*AggregateParams{{ 950 Opcode: AggregateCountDistinct, 951 Col: 1, 952 Alias: "count(distinct c2)", 953 }, { 954 Opcode: AggregateSumDistinct, 955 Col: 2, 956 Alias: "sum(distinct c3)", 957 }}, 958 GroupByKeys: []*GroupByParams{{KeyCol: 0}}, 959 Input: fp, 960 } 961 962 want := sqltypes.MakeTestResult( 963 sqltypes.MakeTestFields( 964 "c1|count(distinct c2)|sum(distinct c3)", 965 "int64|int64|decimal", 966 ), 967 `null|2|2`, 968 `10|2|4`, 969 `20|0|null`, 970 `30|1|6`, 971 `40|3|1`, 972 ) 973 974 qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) 975 require.NoError(t, err) 976 assert.Equal(t, want, qr) 977 978 fp.rewind() 979 results := &sqltypes.Result{} 980 err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error { 981 if qr.Fields != nil { 982 results.Fields = qr.Fields 983 } 984 results.Rows = append(results.Rows, qr.Rows...) 985 return nil 986 }) 987 require.NoError(t, err) 988 assert.Equal(t, want, results) 989 } 990 991 func TestOrderedAggregateCollate(t *testing.T) { 992 assert := assert.New(t) 993 fields := sqltypes.MakeTestFields( 994 "col|count(*)", 995 "varchar|decimal", 996 ) 997 fp := &fakePrimitive{ 998 results: []*sqltypes.Result{sqltypes.MakeTestResult( 999 fields, 1000 "a|1", 1001 "A|1", 1002 "Ǎ|1", 1003 "b|2", 1004 "B|-1", 1005 "c|3", 1006 "c|4", 1007 "ß|11", 1008 "ss|2", 1009 )}, 1010 } 1011 1012 collationID, _ := collationEnv.LookupID("utf8mb4_0900_ai_ci") 1013 oa := &OrderedAggregate{ 1014 Aggregates: []*AggregateParams{{ 1015 Opcode: AggregateSum, 1016 Col: 1, 1017 }}, 1018 GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}}, 1019 Input: fp, 1020 Collations: map[int]collations.ID{0: collationID}, 1021 } 1022 1023 result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) 1024 assert.NoError(err) 1025 1026 wantResult := sqltypes.MakeTestResult( 1027 fields, 1028 "a|3", 1029 "b|1", 1030 "c|7", 1031 "ß|13", 1032 ) 1033 assert.Equal(wantResult, result) 1034 } 1035 1036 func TestOrderedAggregateCollateAS(t *testing.T) { 1037 assert := assert.New(t) 1038 fields := sqltypes.MakeTestFields( 1039 "col|count(*)", 1040 "varchar|decimal", 1041 ) 1042 fp := &fakePrimitive{ 1043 results: []*sqltypes.Result{sqltypes.MakeTestResult( 1044 fields, 1045 "a|1", 1046 "A|1", 1047 "Ǎ|1", 1048 "b|2", 1049 "c|3", 1050 "c|4", 1051 "Ç|4", 1052 )}, 1053 } 1054 1055 collationID, _ := collationEnv.LookupID("utf8mb4_0900_as_ci") 1056 oa := &OrderedAggregate{ 1057 Aggregates: []*AggregateParams{{ 1058 Opcode: AggregateSum, 1059 Col: 1, 1060 }}, 1061 GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}}, 1062 Collations: map[int]collations.ID{0: collationID}, 1063 Input: fp, 1064 } 1065 1066 result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) 1067 assert.NoError(err) 1068 1069 wantResult := sqltypes.MakeTestResult( 1070 fields, 1071 "a|2", 1072 "Ǎ|1", 1073 "b|2", 1074 "c|7", 1075 "Ç|4", 1076 ) 1077 assert.Equal(wantResult, result) 1078 } 1079 1080 func TestOrderedAggregateCollateKS(t *testing.T) { 1081 assert := assert.New(t) 1082 fields := sqltypes.MakeTestFields( 1083 "col|count(*)", 1084 "varchar|decimal", 1085 ) 1086 fp := &fakePrimitive{ 1087 results: []*sqltypes.Result{sqltypes.MakeTestResult( 1088 fields, 1089 "a|1", 1090 "A|1", 1091 "Ǎ|1", 1092 "b|2", 1093 "c|3", 1094 "c|4", 1095 "\xE3\x83\x8F\xE3\x81\xAF|2", 1096 "\xE3\x83\x8F\xE3\x83\x8F|1", 1097 )}, 1098 } 1099 1100 collationID, _ := collationEnv.LookupID("utf8mb4_ja_0900_as_cs_ks") 1101 oa := &OrderedAggregate{ 1102 Aggregates: []*AggregateParams{{ 1103 Opcode: AggregateSum, 1104 Col: 1, 1105 }}, 1106 GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}}, 1107 Collations: map[int]collations.ID{0: collationID}, 1108 Input: fp, 1109 } 1110 1111 result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) 1112 assert.NoError(err) 1113 1114 wantResult := sqltypes.MakeTestResult( 1115 fields, 1116 "a|1", 1117 "A|1", 1118 "Ǎ|1", 1119 "b|2", 1120 "c|7", 1121 "\xE3\x83\x8F\xE3\x81\xAF|2", 1122 "\xE3\x83\x8F\xE3\x83\x8F|1", 1123 ) 1124 assert.Equal(wantResult, result) 1125 }