vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/ordered_aggregate_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"testing"
    24  
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  
    28  	"vitess.io/vitess/go/mysql/collations"
    29  	"vitess.io/vitess/go/sqltypes"
    30  	"vitess.io/vitess/go/test/utils"
    31  
    32  	binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
    33  	querypb "vitess.io/vitess/go/vt/proto/query"
    34  	"vitess.io/vitess/go/vt/servenv"
    35  )
    36  
    37  var collationEnv *collations.Environment
    38  
    39  func init() {
    40  	// We require MySQL 8.0 collations for the comparisons in the tests
    41  	mySQLVersion := "8.0.0"
    42  	servenv.SetMySQLServerVersionForTest(mySQLVersion)
    43  	collationEnv = collations.NewEnvironment(mySQLVersion)
    44  }
    45  
    46  func TestOrderedAggregateExecute(t *testing.T) {
    47  	assert := assert.New(t)
    48  	fields := sqltypes.MakeTestFields(
    49  		"col|count(*)",
    50  		"varbinary|decimal",
    51  	)
    52  	fp := &fakePrimitive{
    53  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
    54  			fields,
    55  			"a|1",
    56  			"a|1",
    57  			"b|2",
    58  			"c|3",
    59  			"c|4",
    60  		)},
    61  	}
    62  
    63  	oa := &OrderedAggregate{
    64  		Aggregates: []*AggregateParams{{
    65  			Opcode: AggregateSum,
    66  			Col:    1,
    67  		}},
    68  		GroupByKeys: []*GroupByParams{{KeyCol: 0}},
    69  		Input:       fp,
    70  	}
    71  
    72  	result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
    73  	assert.NoError(err)
    74  
    75  	wantResult := sqltypes.MakeTestResult(
    76  		fields,
    77  		"a|2",
    78  		"b|2",
    79  		"c|7",
    80  	)
    81  	assert.Equal(wantResult, result)
    82  }
    83  
    84  func TestOrderedAggregateExecuteTruncate(t *testing.T) {
    85  	assert := assert.New(t)
    86  	fp := &fakePrimitive{
    87  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
    88  			sqltypes.MakeTestFields(
    89  				"col|count(*)|weight_string(col)",
    90  				"varchar|decimal|varbinary",
    91  			),
    92  			"a|1|A",
    93  			"A|1|A",
    94  			"b|2|B",
    95  			"C|3|C",
    96  			"c|4|C",
    97  		)},
    98  	}
    99  
   100  	oa := &OrderedAggregate{
   101  		Aggregates: []*AggregateParams{{
   102  			Opcode: AggregateSum,
   103  			Col:    1,
   104  		}},
   105  		GroupByKeys:         []*GroupByParams{{KeyCol: 2}},
   106  		TruncateColumnCount: 2,
   107  		Input:               fp,
   108  	}
   109  
   110  	result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   111  	assert.NoError(err)
   112  
   113  	wantResult := sqltypes.MakeTestResult(
   114  		sqltypes.MakeTestFields(
   115  			"col|count(*)",
   116  			"varchar|decimal",
   117  		),
   118  		"a|2",
   119  		"b|2",
   120  		"C|7",
   121  	)
   122  	assert.Equal(wantResult, result)
   123  }
   124  
   125  func TestOrderedAggregateStreamExecute(t *testing.T) {
   126  	assert := assert.New(t)
   127  	fields := sqltypes.MakeTestFields(
   128  		"col|count(*)",
   129  		"varbinary|decimal",
   130  	)
   131  	fp := &fakePrimitive{
   132  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
   133  			fields,
   134  			"a|1",
   135  			"a|1",
   136  			"b|2",
   137  			"c|3",
   138  			"c|4",
   139  		)},
   140  	}
   141  
   142  	oa := &OrderedAggregate{
   143  		Aggregates: []*AggregateParams{{
   144  			Opcode: AggregateSum,
   145  			Col:    1,
   146  		}},
   147  		GroupByKeys: []*GroupByParams{{KeyCol: 0}},
   148  		Input:       fp,
   149  	}
   150  
   151  	var results []*sqltypes.Result
   152  	err := oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error {
   153  		results = append(results, qr)
   154  		return nil
   155  	})
   156  	assert.NoError(err)
   157  
   158  	wantResults := sqltypes.MakeTestStreamingResults(
   159  		fields,
   160  		"a|2",
   161  		"---",
   162  		"b|2",
   163  		"---",
   164  		"c|7",
   165  	)
   166  	assert.Equal(wantResults, results)
   167  }
   168  
   169  func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) {
   170  	assert := assert.New(t)
   171  	fp := &fakePrimitive{
   172  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
   173  			sqltypes.MakeTestFields(
   174  				"col|count(*)|weight_string(col)",
   175  				"varchar|decimal|varbinary",
   176  			),
   177  			"a|1|A",
   178  			"A|1|A",
   179  			"b|2|B",
   180  			"C|3|C",
   181  			"c|4|C",
   182  		)},
   183  	}
   184  
   185  	oa := &OrderedAggregate{
   186  		Aggregates: []*AggregateParams{{
   187  			Opcode: AggregateSum,
   188  			Col:    1,
   189  		}},
   190  		GroupByKeys:         []*GroupByParams{{KeyCol: 2}},
   191  		TruncateColumnCount: 2,
   192  		Input:               fp,
   193  	}
   194  
   195  	var results []*sqltypes.Result
   196  	err := oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error {
   197  		results = append(results, qr)
   198  		return nil
   199  	})
   200  	assert.NoError(err)
   201  
   202  	wantResults := sqltypes.MakeTestStreamingResults(
   203  		sqltypes.MakeTestFields(
   204  			"col|count(*)",
   205  			"varchar|decimal",
   206  		),
   207  		"a|2",
   208  		"---",
   209  		"b|2",
   210  		"---",
   211  		"C|7",
   212  	)
   213  	assert.Equal(wantResults, results)
   214  }
   215  
   216  func TestOrderedAggregateGetFields(t *testing.T) {
   217  	assert := assert.New(t)
   218  	input := sqltypes.MakeTestResult(
   219  		sqltypes.MakeTestFields(
   220  			"col|count(*)",
   221  			"varbinary|decimal",
   222  		),
   223  	)
   224  	fp := &fakePrimitive{results: []*sqltypes.Result{input}}
   225  
   226  	oa := &OrderedAggregate{Input: fp}
   227  
   228  	got, err := oa.GetFields(context.Background(), nil, nil)
   229  	assert.NoError(err)
   230  	assert.Equal(got, input)
   231  }
   232  
   233  func TestOrderedAggregateGetFieldsTruncate(t *testing.T) {
   234  	assert := assert.New(t)
   235  	result := sqltypes.MakeTestResult(
   236  		sqltypes.MakeTestFields(
   237  			"col|count(*)|weight_string(col)",
   238  			"varchar|decimal|varbinary",
   239  		),
   240  	)
   241  	fp := &fakePrimitive{results: []*sqltypes.Result{result}}
   242  
   243  	oa := &OrderedAggregate{
   244  		TruncateColumnCount: 2,
   245  		Input:               fp,
   246  	}
   247  
   248  	got, err := oa.GetFields(context.Background(), nil, nil)
   249  	assert.NoError(err)
   250  	wantResult := sqltypes.MakeTestResult(
   251  		sqltypes.MakeTestFields(
   252  			"col|count(*)",
   253  			"varchar|decimal",
   254  		),
   255  	)
   256  	assert.Equal(wantResult, got)
   257  }
   258  
   259  func TestOrderedAggregateInputFail(t *testing.T) {
   260  	fp := &fakePrimitive{sendErr: errors.New("input fail")}
   261  
   262  	oa := &OrderedAggregate{Input: fp}
   263  
   264  	want := "input fail"
   265  	if _, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false); err == nil || err.Error() != want {
   266  		t.Errorf("oa.Execute(): %v, want %s", err, want)
   267  	}
   268  
   269  	fp.rewind()
   270  	if err := oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, false, func(_ *sqltypes.Result) error { return nil }); err == nil || err.Error() != want {
   271  		t.Errorf("oa.StreamExecute(): %v, want %s", err, want)
   272  	}
   273  
   274  	fp.rewind()
   275  	if _, err := oa.GetFields(context.Background(), nil, nil); err == nil || err.Error() != want {
   276  		t.Errorf("oa.GetFields(): %v, want %s", err, want)
   277  	}
   278  }
   279  
   280  func TestOrderedAggregateExecuteCountDistinct(t *testing.T) {
   281  	assert := assert.New(t)
   282  	fp := &fakePrimitive{
   283  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
   284  			sqltypes.MakeTestFields(
   285  				"col1|col2|count(*)",
   286  				"varbinary|decimal|int64",
   287  			),
   288  			// Two identical values
   289  			"a|1|1",
   290  			"a|1|2",
   291  			// Single value
   292  			"b|1|1",
   293  			// Two different values
   294  			"c|3|1",
   295  			"c|4|1",
   296  			// Single null
   297  			"d|null|1",
   298  			// Start with null
   299  			"e|null|1",
   300  			"e|1|1",
   301  			// Null comes after first
   302  			"f|1|1",
   303  			"f|null|1",
   304  			// Identical to non-identical transition
   305  			"g|1|1",
   306  			"g|1|1",
   307  			"g|2|1",
   308  			"g|3|1",
   309  			// Non-identical to identical transition
   310  			"h|1|1",
   311  			"h|2|1",
   312  			"h|2|1",
   313  			"h|3|1",
   314  			// Key transition, should still count 3
   315  			"i|3|1",
   316  			"i|4|1",
   317  		)},
   318  	}
   319  
   320  	oa := &OrderedAggregate{
   321  		PreProcess: true,
   322  		Aggregates: []*AggregateParams{{
   323  			Opcode: AggregateCountDistinct,
   324  			Col:    1,
   325  			Alias:  "count(distinct col2)",
   326  		}, {
   327  			// Also add a count(*)
   328  			Opcode: AggregateSum,
   329  			Col:    2,
   330  		}},
   331  		GroupByKeys: []*GroupByParams{{KeyCol: 0}},
   332  		Input:       fp,
   333  	}
   334  
   335  	result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   336  	assert.NoError(err)
   337  
   338  	wantResult := sqltypes.MakeTestResult(
   339  		sqltypes.MakeTestFields(
   340  			"col1|count(distinct col2)|count(*)",
   341  			"varbinary|int64|int64",
   342  		),
   343  		"a|1|3",
   344  		"b|1|1",
   345  		"c|2|2",
   346  		"d|0|1",
   347  		"e|1|2",
   348  		"f|1|2",
   349  		"g|3|4",
   350  		"h|3|4",
   351  		"i|2|2",
   352  	)
   353  	assert.Equal(wantResult, result)
   354  }
   355  
   356  func TestOrderedAggregateStreamCountDistinct(t *testing.T) {
   357  	assert := assert.New(t)
   358  	fp := &fakePrimitive{
   359  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
   360  			sqltypes.MakeTestFields(
   361  				"col1|col2|count(*)",
   362  				"varbinary|decimal|int64",
   363  			),
   364  			// Two identical values
   365  			"a|1|1",
   366  			"a|1|2",
   367  			// Single value
   368  			"b|1|1",
   369  			// Two different values
   370  			"c|3|1",
   371  			"c|4|1",
   372  			// Single null
   373  			"d|null|1",
   374  			// Start with null
   375  			"e|null|1",
   376  			"e|1|1",
   377  			// Null comes after first
   378  			"f|1|1",
   379  			"f|null|1",
   380  			// Identical to non-identical transition
   381  			"g|1|1",
   382  			"g|1|1",
   383  			"g|2|1",
   384  			"g|3|1",
   385  			// Non-identical to identical transition
   386  			"h|1|1",
   387  			"h|2|1",
   388  			"h|2|1",
   389  			"h|3|1",
   390  			// Key transition, should still count 3
   391  			"i|3|1",
   392  			"i|4|1",
   393  		)},
   394  	}
   395  
   396  	oa := &OrderedAggregate{
   397  		PreProcess: true,
   398  		Aggregates: []*AggregateParams{{
   399  			Opcode: AggregateCountDistinct,
   400  			Col:    1,
   401  			Alias:  "count(distinct col2)",
   402  		}, {
   403  			// Also add a count(*)
   404  			Opcode: AggregateSum,
   405  			Col:    2,
   406  		}},
   407  		GroupByKeys: []*GroupByParams{{KeyCol: 0}},
   408  		Input:       fp,
   409  	}
   410  
   411  	var results []*sqltypes.Result
   412  	err := oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error {
   413  		results = append(results, qr)
   414  		return nil
   415  	})
   416  	assert.NoError(err)
   417  
   418  	wantResults := sqltypes.MakeTestStreamingResults(
   419  		sqltypes.MakeTestFields(
   420  			"col1|count(distinct col2)|count(*)",
   421  			"varbinary|int64|int64",
   422  		),
   423  		"a|1|3",
   424  		"-----",
   425  		"b|1|1",
   426  		"-----",
   427  		"c|2|2",
   428  		"-----",
   429  		"d|0|1",
   430  		"-----",
   431  		"e|1|2",
   432  		"-----",
   433  		"f|1|2",
   434  		"-----",
   435  		"g|3|4",
   436  		"-----",
   437  		"h|3|4",
   438  		"-----",
   439  		"i|2|2",
   440  	)
   441  	assert.Equal(wantResults, results)
   442  }
   443  
   444  func TestOrderedAggregateSumDistinctGood(t *testing.T) {
   445  	assert := assert.New(t)
   446  	fp := &fakePrimitive{
   447  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
   448  			sqltypes.MakeTestFields(
   449  				"col1|col2|sum(col3)",
   450  				"varbinary|int64|decimal",
   451  			),
   452  			// Two identical values
   453  			"a|1|1",
   454  			"a|1|2",
   455  			// Single value
   456  			"b|1|1",
   457  			// Two different values
   458  			"c|3|1",
   459  			"c|4|1",
   460  			// Single null
   461  			"d|null|1",
   462  			"d|1|1",
   463  			// Start with null
   464  			"e|null|1",
   465  			"e|1|1",
   466  			// Null comes after first
   467  			"f|1|1",
   468  			"f|null|1",
   469  			// Identical to non-identical transition
   470  			"g|1|1",
   471  			"g|1|1",
   472  			"g|2|1",
   473  			"g|3|1",
   474  			// Non-identical to identical transition
   475  			"h|1|1",
   476  			"h|2|1",
   477  			"h|2|1",
   478  			"h|3|1",
   479  			// Key transition, should still count 3
   480  			"i|3|1",
   481  			"i|4|1",
   482  		)},
   483  	}
   484  
   485  	oa := &OrderedAggregate{
   486  		PreProcess: true,
   487  		Aggregates: []*AggregateParams{{
   488  			Opcode: AggregateSumDistinct,
   489  			Col:    1,
   490  			Alias:  "sum(distinct col2)",
   491  		}, {
   492  			// Also add a count(*)
   493  			Opcode: AggregateSum,
   494  			Col:    2,
   495  		}},
   496  		GroupByKeys: []*GroupByParams{{KeyCol: 0}},
   497  		Input:       fp,
   498  	}
   499  
   500  	result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   501  	assert.NoError(err)
   502  
   503  	wantResult := sqltypes.MakeTestResult(
   504  		sqltypes.MakeTestFields(
   505  			"col1|sum(distinct col2)|sum(col3)",
   506  			"varbinary|decimal|decimal",
   507  		),
   508  		"a|1|3",
   509  		"b|1|1",
   510  		"c|7|2",
   511  		"d|1|2",
   512  		"e|1|2",
   513  		"f|1|2",
   514  		"g|6|4",
   515  		"h|6|4",
   516  		"i|7|2",
   517  	)
   518  	want := fmt.Sprintf("%v", wantResult.Rows)
   519  	got := fmt.Sprintf("%v", result.Rows)
   520  	assert.Equal(want, got)
   521  }
   522  
   523  func TestOrderedAggregateSumDistinctTolerateError(t *testing.T) {
   524  	fp := &fakePrimitive{
   525  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
   526  			sqltypes.MakeTestFields(
   527  				"col1|col2",
   528  				"varbinary|varbinary",
   529  			),
   530  			"a|aaa",
   531  			"a|0",
   532  			"a|1",
   533  		)},
   534  	}
   535  
   536  	oa := &OrderedAggregate{
   537  		PreProcess: true,
   538  		Aggregates: []*AggregateParams{{
   539  			Opcode: AggregateSumDistinct,
   540  			Col:    1,
   541  			Alias:  "sum(distinct col2)",
   542  		}},
   543  		GroupByKeys: []*GroupByParams{{KeyCol: 0}},
   544  		Input:       fp,
   545  	}
   546  
   547  	result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   548  	assert.NoError(t, err)
   549  
   550  	wantResult := sqltypes.MakeTestResult(
   551  		sqltypes.MakeTestFields(
   552  			"col1|sum(distinct col2)",
   553  			"varbinary|decimal",
   554  		),
   555  		"a|1",
   556  	)
   557  	utils.MustMatch(t, wantResult, result, "")
   558  }
   559  
   560  func TestOrderedAggregateKeysFail(t *testing.T) {
   561  	fields := sqltypes.MakeTestFields(
   562  		"col|count(*)",
   563  		"varchar|decimal",
   564  	)
   565  	fp := &fakePrimitive{
   566  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
   567  			fields,
   568  			"a|1",
   569  			"a|1",
   570  		)},
   571  	}
   572  
   573  	oa := &OrderedAggregate{
   574  		Aggregates: []*AggregateParams{{
   575  			Opcode: AggregateSum,
   576  			Col:    1,
   577  		}},
   578  		GroupByKeys: []*GroupByParams{{KeyCol: 0}},
   579  		Input:       fp,
   580  	}
   581  
   582  	want := "cannot compare strings, collation is unknown or unsupported (collation ID: 0)"
   583  	if _, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false); err == nil || err.Error() != want {
   584  		t.Errorf("oa.Execute(): %v, want %s", err, want)
   585  	}
   586  
   587  	fp.rewind()
   588  	if err := oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, false, func(_ *sqltypes.Result) error { return nil }); err == nil || err.Error() != want {
   589  		t.Errorf("oa.StreamExecute(): %v, want %s", err, want)
   590  	}
   591  }
   592  
   593  func TestOrderedAggregateMergeFail(t *testing.T) {
   594  	fields := sqltypes.MakeTestFields(
   595  		"col|count(*)",
   596  		"varbinary|decimal",
   597  	)
   598  	fp := &fakePrimitive{
   599  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
   600  			fields,
   601  			"a|1",
   602  			"a|0",
   603  		)},
   604  	}
   605  
   606  	oa := &OrderedAggregate{
   607  		Aggregates: []*AggregateParams{{
   608  			Opcode: AggregateSum,
   609  			Col:    1,
   610  		}},
   611  		GroupByKeys: []*GroupByParams{{KeyCol: 0}},
   612  		Input:       fp,
   613  	}
   614  
   615  	result := &sqltypes.Result{
   616  		Fields: []*querypb.Field{
   617  			{
   618  				Name: "col",
   619  				Type: querypb.Type_VARBINARY,
   620  			},
   621  			{
   622  				Name: "count(*)",
   623  				Type: querypb.Type_DECIMAL,
   624  			},
   625  		},
   626  		Rows: [][]sqltypes.Value{
   627  			{
   628  				sqltypes.MakeTrusted(querypb.Type_VARBINARY, []byte("a")),
   629  				sqltypes.MakeTrusted(querypb.Type_DECIMAL, []byte("1")),
   630  			},
   631  		},
   632  	}
   633  
   634  	res, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   635  	require.NoError(t, err)
   636  
   637  	utils.MustMatch(t, result, res, "Found mismatched values")
   638  
   639  	fp.rewind()
   640  	err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(_ *sqltypes.Result) error { return nil })
   641  	require.NoError(t, err)
   642  }
   643  
   644  func TestMerge(t *testing.T) {
   645  	assert := assert.New(t)
   646  	oa := &OrderedAggregate{
   647  		Aggregates: []*AggregateParams{{
   648  			Opcode: AggregateSum,
   649  			Col:    1,
   650  		}, {
   651  			Opcode: AggregateSum,
   652  			Col:    2,
   653  		}, {
   654  			Opcode: AggregateMin,
   655  			Col:    3,
   656  		}, {
   657  			Opcode: AggregateMax,
   658  			Col:    4,
   659  		}},
   660  	}
   661  	fields := sqltypes.MakeTestFields(
   662  		"a|b|c|d|e",
   663  		"int64|int64|decimal|in32|varbinary",
   664  	)
   665  	r := sqltypes.MakeTestResult(fields,
   666  		"1|2|3.2|3|ab",
   667  		"1|3|2.8|2|bc",
   668  	)
   669  
   670  	merged, _, err := merge(fields, r.Rows[0], r.Rows[1], nil, nil, oa.Aggregates)
   671  	assert.NoError(err)
   672  	want := sqltypes.MakeTestResult(fields, "1|5|6.0|2|bc").Rows[0]
   673  	assert.Equal(want, merged)
   674  
   675  	// swap and retry
   676  	merged, _, err = merge(fields, r.Rows[1], r.Rows[0], nil, nil, oa.Aggregates)
   677  	assert.NoError(err)
   678  	assert.Equal(want, merged)
   679  }
   680  
   681  func TestOrderedAggregateExecuteGtid(t *testing.T) {
   682  	vgtid := binlogdatapb.VGtid{}
   683  	vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{
   684  		Keyspace: "ks",
   685  		Shard:    "-80",
   686  		Gtid:     "a",
   687  	})
   688  	vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{
   689  		Keyspace: "ks",
   690  		Shard:    "80-",
   691  		Gtid:     "b",
   692  	})
   693  
   694  	fp := &fakePrimitive{
   695  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
   696  			sqltypes.MakeTestFields(
   697  				"keyspace|gtid|shard",
   698  				"varchar|varchar|varchar",
   699  			),
   700  			"ks|a|-40",
   701  			"ks|b|40-80",
   702  			"ks|c|80-c0",
   703  			"ks|d|c0-",
   704  		)},
   705  	}
   706  
   707  	oa := &OrderedAggregate{
   708  		PreProcess: true,
   709  		Aggregates: []*AggregateParams{{
   710  			Opcode: AggregateGtid,
   711  			Col:    1,
   712  			Alias:  "vgtid",
   713  		}},
   714  		TruncateColumnCount: 2,
   715  		Input:               fp,
   716  	}
   717  
   718  	result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   719  	require.NoError(t, err)
   720  
   721  	wantResult := sqltypes.MakeTestResult(
   722  		sqltypes.MakeTestFields(
   723  			"keyspace|vgtid",
   724  			"varchar|varchar",
   725  		),
   726  		`ks|shard_gtids:{keyspace:"ks" shard:"-40" gtid:"a"} shard_gtids:{keyspace:"ks" shard:"40-80" gtid:"b"} shard_gtids:{keyspace:"ks" shard:"80-c0" gtid:"c"} shard_gtids:{keyspace:"ks" shard:"c0-" gtid:"d"}`,
   727  	)
   728  	assert.Equal(t, wantResult, result)
   729  }
   730  
   731  func TestCountDistinctOnVarchar(t *testing.T) {
   732  	fields := sqltypes.MakeTestFields(
   733  		"c1|c2|weight_string(c2)",
   734  		"int64|varchar|varbinary",
   735  	)
   736  	fp := &fakePrimitive{
   737  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
   738  			fields,
   739  			"10|a|0x41",
   740  			"10|a|0x41",
   741  			"10|b|0x42",
   742  			"20|b|0x42",
   743  		)},
   744  	}
   745  
   746  	oa := &OrderedAggregate{
   747  		PreProcess: true,
   748  		Aggregates: []*AggregateParams{{
   749  			Opcode:    AggregateCountDistinct,
   750  			Col:       1,
   751  			WCol:      2,
   752  			WAssigned: true,
   753  			Alias:     "count(distinct c2)",
   754  		}},
   755  		GroupByKeys:         []*GroupByParams{{KeyCol: 0}},
   756  		Input:               fp,
   757  		TruncateColumnCount: 2,
   758  	}
   759  
   760  	want := sqltypes.MakeTestResult(
   761  		sqltypes.MakeTestFields(
   762  			"c1|count(distinct c2)",
   763  			"int64|int64",
   764  		),
   765  		`10|2`,
   766  		`20|1`,
   767  	)
   768  
   769  	qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   770  	require.NoError(t, err)
   771  	assert.Equal(t, want, qr)
   772  
   773  	fp.rewind()
   774  	results := &sqltypes.Result{}
   775  	err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error {
   776  		if qr.Fields != nil {
   777  			results.Fields = qr.Fields
   778  		}
   779  		results.Rows = append(results.Rows, qr.Rows...)
   780  		return nil
   781  	})
   782  	require.NoError(t, err)
   783  	assert.Equal(t, want, results)
   784  }
   785  
   786  func TestCountDistinctOnVarcharWithNulls(t *testing.T) {
   787  	fields := sqltypes.MakeTestFields(
   788  		"c1|c2|weight_string(c2)",
   789  		"int64|varchar|varbinary",
   790  	)
   791  	fp := &fakePrimitive{
   792  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
   793  			fields,
   794  			"null|null|null",
   795  			"null|a|0x41",
   796  			"null|b|0x42",
   797  			"10|null|null",
   798  			"10|null|null",
   799  			"10|a|0x41",
   800  			"10|a|0x41",
   801  			"10|b|0x42",
   802  			"20|null|null",
   803  			"20|b|0x42",
   804  			"30|null|null",
   805  			"30|null|null",
   806  			"30|null|null",
   807  			"30|null|null",
   808  		)},
   809  	}
   810  
   811  	oa := &OrderedAggregate{
   812  		PreProcess: true,
   813  		Aggregates: []*AggregateParams{{
   814  			Opcode:    AggregateCountDistinct,
   815  			Col:       1,
   816  			WCol:      2,
   817  			WAssigned: true,
   818  			Alias:     "count(distinct c2)",
   819  		}},
   820  		GroupByKeys:         []*GroupByParams{{KeyCol: 0}},
   821  		Input:               fp,
   822  		TruncateColumnCount: 2,
   823  	}
   824  
   825  	want := sqltypes.MakeTestResult(
   826  		sqltypes.MakeTestFields(
   827  			"c1|count(distinct c2)",
   828  			"int64|int64",
   829  		),
   830  		`null|2`,
   831  		`10|2`,
   832  		`20|1`,
   833  		`30|0`,
   834  	)
   835  
   836  	qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   837  	require.NoError(t, err)
   838  	assert.Equal(t, want, qr)
   839  
   840  	fp.rewind()
   841  	results := &sqltypes.Result{}
   842  	err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error {
   843  		if qr.Fields != nil {
   844  			results.Fields = qr.Fields
   845  		}
   846  		results.Rows = append(results.Rows, qr.Rows...)
   847  		return nil
   848  	})
   849  	require.NoError(t, err)
   850  	assert.Equal(t, want, results)
   851  }
   852  
   853  func TestSumDistinctOnVarcharWithNulls(t *testing.T) {
   854  	fields := sqltypes.MakeTestFields(
   855  		"c1|c2|weight_string(c2)",
   856  		"int64|varchar|varbinary",
   857  	)
   858  	fp := &fakePrimitive{
   859  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
   860  			fields,
   861  			"null|null|null",
   862  			"null|a|0x41",
   863  			"null|b|0x42",
   864  			"10|null|null",
   865  			"10|null|null",
   866  			"10|a|0x41",
   867  			"10|a|0x41",
   868  			"10|b|0x42",
   869  			"20|null|null",
   870  			"20|b|0x42",
   871  			"30|null|null",
   872  			"30|null|null",
   873  			"30|null|null",
   874  			"30|null|null",
   875  		)},
   876  	}
   877  
   878  	oa := &OrderedAggregate{
   879  		PreProcess: true,
   880  		Aggregates: []*AggregateParams{{
   881  			Opcode:    AggregateSumDistinct,
   882  			Col:       1,
   883  			WCol:      2,
   884  			WAssigned: true,
   885  			Alias:     "sum(distinct c2)",
   886  		}},
   887  		GroupByKeys:         []*GroupByParams{{KeyCol: 0}},
   888  		Input:               fp,
   889  		TruncateColumnCount: 2,
   890  	}
   891  
   892  	want := sqltypes.MakeTestResult(
   893  		sqltypes.MakeTestFields(
   894  			"c1|sum(distinct c2)",
   895  			"int64|decimal",
   896  		),
   897  		`null|0`,
   898  		`10|0`,
   899  		`20|0`,
   900  		`30|null`,
   901  	)
   902  
   903  	qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   904  	require.NoError(t, err)
   905  	assert.Equal(t, want, qr)
   906  
   907  	fp.rewind()
   908  	results := &sqltypes.Result{}
   909  	err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error {
   910  		if qr.Fields != nil {
   911  			results.Fields = qr.Fields
   912  		}
   913  		results.Rows = append(results.Rows, qr.Rows...)
   914  		return nil
   915  	})
   916  	require.NoError(t, err)
   917  	assert.Equal(t, want, results)
   918  }
   919  
   920  func TestMultiDistinct(t *testing.T) {
   921  	fields := sqltypes.MakeTestFields(
   922  		"c1|c2|c3",
   923  		"int64|int64|int64",
   924  	)
   925  	fp := &fakePrimitive{
   926  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
   927  			fields,
   928  			"null|null|null",
   929  			"null|1|2",
   930  			"null|2|2",
   931  			"10|null|null",
   932  			"10|2|null",
   933  			"10|2|1",
   934  			"10|2|3",
   935  			"10|3|3",
   936  			"20|null|null",
   937  			"20|null|null",
   938  			"30|1|1",
   939  			"30|1|2",
   940  			"30|1|3",
   941  			"40|1|1",
   942  			"40|2|1",
   943  			"40|3|1",
   944  		)},
   945  	}
   946  
   947  	oa := &OrderedAggregate{
   948  		PreProcess: true,
   949  		Aggregates: []*AggregateParams{{
   950  			Opcode: AggregateCountDistinct,
   951  			Col:    1,
   952  			Alias:  "count(distinct c2)",
   953  		}, {
   954  			Opcode: AggregateSumDistinct,
   955  			Col:    2,
   956  			Alias:  "sum(distinct c3)",
   957  		}},
   958  		GroupByKeys: []*GroupByParams{{KeyCol: 0}},
   959  		Input:       fp,
   960  	}
   961  
   962  	want := sqltypes.MakeTestResult(
   963  		sqltypes.MakeTestFields(
   964  			"c1|count(distinct c2)|sum(distinct c3)",
   965  			"int64|int64|decimal",
   966  		),
   967  		`null|2|2`,
   968  		`10|2|4`,
   969  		`20|0|null`,
   970  		`30|1|6`,
   971  		`40|3|1`,
   972  	)
   973  
   974  	qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   975  	require.NoError(t, err)
   976  	assert.Equal(t, want, qr)
   977  
   978  	fp.rewind()
   979  	results := &sqltypes.Result{}
   980  	err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error {
   981  		if qr.Fields != nil {
   982  			results.Fields = qr.Fields
   983  		}
   984  		results.Rows = append(results.Rows, qr.Rows...)
   985  		return nil
   986  	})
   987  	require.NoError(t, err)
   988  	assert.Equal(t, want, results)
   989  }
   990  
   991  func TestOrderedAggregateCollate(t *testing.T) {
   992  	assert := assert.New(t)
   993  	fields := sqltypes.MakeTestFields(
   994  		"col|count(*)",
   995  		"varchar|decimal",
   996  	)
   997  	fp := &fakePrimitive{
   998  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
   999  			fields,
  1000  			"a|1",
  1001  			"A|1",
  1002  			"Ǎ|1",
  1003  			"b|2",
  1004  			"B|-1",
  1005  			"c|3",
  1006  			"c|4",
  1007  			"ß|11",
  1008  			"ss|2",
  1009  		)},
  1010  	}
  1011  
  1012  	collationID, _ := collationEnv.LookupID("utf8mb4_0900_ai_ci")
  1013  	oa := &OrderedAggregate{
  1014  		Aggregates: []*AggregateParams{{
  1015  			Opcode: AggregateSum,
  1016  			Col:    1,
  1017  		}},
  1018  		GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}},
  1019  		Input:       fp,
  1020  		Collations:  map[int]collations.ID{0: collationID},
  1021  	}
  1022  
  1023  	result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
  1024  	assert.NoError(err)
  1025  
  1026  	wantResult := sqltypes.MakeTestResult(
  1027  		fields,
  1028  		"a|3",
  1029  		"b|1",
  1030  		"c|7",
  1031  		"ß|13",
  1032  	)
  1033  	assert.Equal(wantResult, result)
  1034  }
  1035  
  1036  func TestOrderedAggregateCollateAS(t *testing.T) {
  1037  	assert := assert.New(t)
  1038  	fields := sqltypes.MakeTestFields(
  1039  		"col|count(*)",
  1040  		"varchar|decimal",
  1041  	)
  1042  	fp := &fakePrimitive{
  1043  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
  1044  			fields,
  1045  			"a|1",
  1046  			"A|1",
  1047  			"Ǎ|1",
  1048  			"b|2",
  1049  			"c|3",
  1050  			"c|4",
  1051  			"Ç|4",
  1052  		)},
  1053  	}
  1054  
  1055  	collationID, _ := collationEnv.LookupID("utf8mb4_0900_as_ci")
  1056  	oa := &OrderedAggregate{
  1057  		Aggregates: []*AggregateParams{{
  1058  			Opcode: AggregateSum,
  1059  			Col:    1,
  1060  		}},
  1061  		GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}},
  1062  		Collations:  map[int]collations.ID{0: collationID},
  1063  		Input:       fp,
  1064  	}
  1065  
  1066  	result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
  1067  	assert.NoError(err)
  1068  
  1069  	wantResult := sqltypes.MakeTestResult(
  1070  		fields,
  1071  		"a|2",
  1072  		"Ǎ|1",
  1073  		"b|2",
  1074  		"c|7",
  1075  		"Ç|4",
  1076  	)
  1077  	assert.Equal(wantResult, result)
  1078  }
  1079  
  1080  func TestOrderedAggregateCollateKS(t *testing.T) {
  1081  	assert := assert.New(t)
  1082  	fields := sqltypes.MakeTestFields(
  1083  		"col|count(*)",
  1084  		"varchar|decimal",
  1085  	)
  1086  	fp := &fakePrimitive{
  1087  		results: []*sqltypes.Result{sqltypes.MakeTestResult(
  1088  			fields,
  1089  			"a|1",
  1090  			"A|1",
  1091  			"Ǎ|1",
  1092  			"b|2",
  1093  			"c|3",
  1094  			"c|4",
  1095  			"\xE3\x83\x8F\xE3\x81\xAF|2",
  1096  			"\xE3\x83\x8F\xE3\x83\x8F|1",
  1097  		)},
  1098  	}
  1099  
  1100  	collationID, _ := collationEnv.LookupID("utf8mb4_ja_0900_as_cs_ks")
  1101  	oa := &OrderedAggregate{
  1102  		Aggregates: []*AggregateParams{{
  1103  			Opcode: AggregateSum,
  1104  			Col:    1,
  1105  		}},
  1106  		GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}},
  1107  		Collations:  map[int]collations.ID{0: collationID},
  1108  		Input:       fp,
  1109  	}
  1110  
  1111  	result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
  1112  	assert.NoError(err)
  1113  
  1114  	wantResult := sqltypes.MakeTestResult(
  1115  		fields,
  1116  		"a|1",
  1117  		"A|1",
  1118  		"Ǎ|1",
  1119  		"b|2",
  1120  		"c|7",
  1121  		"\xE3\x83\x8F\xE3\x81\xAF|2",
  1122  		"\xE3\x83\x8F\xE3\x83\x8F|1",
  1123  	)
  1124  	assert.Equal(wantResult, result)
  1125  }