vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/merge_sort_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"testing"
    23  
    24  	"vitess.io/vitess/go/mysql/collations"
    25  	"vitess.io/vitess/go/test/utils"
    26  
    27  	"github.com/stretchr/testify/require"
    28  
    29  	"vitess.io/vitess/go/sqltypes"
    30  
    31  	querypb "vitess.io/vitess/go/vt/proto/query"
    32  )
    33  
    34  // TestMergeSortNormal tests the normal flow of a merge
    35  // sort where all shards return ascending rows.
    36  func TestMergeSortNormal(t *testing.T) {
    37  	idColFields := sqltypes.MakeTestFields("id|col", "int32|varchar")
    38  	shardResults := []*shardResult{{
    39  		results: sqltypes.MakeTestStreamingResults(idColFields,
    40  			"1|a",
    41  			"7|g",
    42  		),
    43  	}, {
    44  		results: sqltypes.MakeTestStreamingResults(idColFields,
    45  			"2|b",
    46  			"---",
    47  			"3|c",
    48  		),
    49  	}, {
    50  		results: sqltypes.MakeTestStreamingResults(idColFields,
    51  			"4|d",
    52  			"6|f",
    53  		),
    54  	}, {
    55  		results: sqltypes.MakeTestStreamingResults(idColFields,
    56  			"4|d",
    57  			"---",
    58  			"8|h",
    59  		),
    60  	}}
    61  	orderBy := []OrderByParams{{
    62  		WeightStringCol: -1,
    63  		Col:             0,
    64  	}}
    65  
    66  	var results []*sqltypes.Result
    67  	err := testMergeSort(shardResults, orderBy, func(qr *sqltypes.Result) error {
    68  		results = append(results, qr)
    69  		return nil
    70  	})
    71  	require.NoError(t, err)
    72  
    73  	// Results are returned one row at a time.
    74  	wantResults := sqltypes.MakeTestStreamingResults(idColFields,
    75  		"1|a",
    76  		"---",
    77  		"2|b",
    78  		"---",
    79  		"3|c",
    80  		"---",
    81  		"4|d",
    82  		"---",
    83  		"4|d",
    84  		"---",
    85  		"6|f",
    86  		"---",
    87  		"7|g",
    88  		"---",
    89  		"8|h",
    90  	)
    91  	utils.MustMatch(t, wantResults, results)
    92  }
    93  
    94  func TestMergeSortWeightString(t *testing.T) {
    95  	idColFields := sqltypes.MakeTestFields("id|col", "varbinary|varchar")
    96  	shardResults := []*shardResult{{
    97  		results: sqltypes.MakeTestStreamingResults(idColFields,
    98  			"1|a",
    99  			"7|g",
   100  		),
   101  	}, {
   102  		results: sqltypes.MakeTestStreamingResults(idColFields,
   103  			"2|b",
   104  			"---",
   105  			"3|c",
   106  		),
   107  	}, {
   108  		results: sqltypes.MakeTestStreamingResults(idColFields,
   109  			"4|d",
   110  			"6|f",
   111  		),
   112  	}, {
   113  		results: sqltypes.MakeTestStreamingResults(idColFields,
   114  			"4|d",
   115  			"---",
   116  			"8|h",
   117  		),
   118  	}}
   119  	orderBy := []OrderByParams{{
   120  		WeightStringCol: 0,
   121  		Col:             1,
   122  	}}
   123  
   124  	var results []*sqltypes.Result
   125  	err := testMergeSort(shardResults, orderBy, func(qr *sqltypes.Result) error {
   126  		results = append(results, qr)
   127  		return nil
   128  	})
   129  	require.NoError(t, err)
   130  
   131  	// Results are returned one row at a time.
   132  	wantResults := sqltypes.MakeTestStreamingResults(idColFields,
   133  		"1|a",
   134  		"---",
   135  		"2|b",
   136  		"---",
   137  		"3|c",
   138  		"---",
   139  		"4|d",
   140  		"---",
   141  		"4|d",
   142  		"---",
   143  		"6|f",
   144  		"---",
   145  		"7|g",
   146  		"---",
   147  		"8|h",
   148  	)
   149  	utils.MustMatch(t, wantResults, results)
   150  }
   151  
   152  func TestMergeSortCollation(t *testing.T) {
   153  	idColFields := sqltypes.MakeTestFields("normal", "varchar")
   154  	shardResults := []*shardResult{{
   155  		results: sqltypes.MakeTestStreamingResults(idColFields,
   156  			"c",
   157  			"---",
   158  			"d",
   159  		),
   160  	}, {
   161  		results: sqltypes.MakeTestStreamingResults(idColFields,
   162  			"cs",
   163  			"---",
   164  			"d",
   165  		),
   166  	}, {
   167  		results: sqltypes.MakeTestStreamingResults(idColFields,
   168  			"cs",
   169  			"---",
   170  			"lu",
   171  		),
   172  	}, {
   173  		results: sqltypes.MakeTestStreamingResults(idColFields,
   174  			"a",
   175  			"---",
   176  			"c",
   177  		),
   178  	}}
   179  
   180  	collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci")
   181  	orderBy := []OrderByParams{{
   182  		Col:         0,
   183  		CollationID: collationID,
   184  	}}
   185  
   186  	var results []*sqltypes.Result
   187  	err := testMergeSort(shardResults, orderBy, func(qr *sqltypes.Result) error {
   188  		results = append(results, qr)
   189  		return nil
   190  	})
   191  	require.NoError(t, err)
   192  
   193  	// Results are returned one row at a time.
   194  	wantResults := sqltypes.MakeTestStreamingResults(idColFields,
   195  		"a",
   196  		"---",
   197  		"c",
   198  		"---",
   199  		"c",
   200  		"---",
   201  		"cs",
   202  		"---",
   203  		"cs",
   204  		"---",
   205  		"d",
   206  		"---",
   207  		"d",
   208  		"---",
   209  		"lu",
   210  	)
   211  	utils.MustMatch(t, wantResults, results)
   212  }
   213  
   214  // TestMergeSortDescending tests the normal flow of a merge
   215  // sort where all shards return descending rows.
   216  func TestMergeSortDescending(t *testing.T) {
   217  	idColFields := sqltypes.MakeTestFields("id|col", "int32|varchar")
   218  	shardResults := []*shardResult{{
   219  		results: sqltypes.MakeTestStreamingResults(idColFields,
   220  			"7|g",
   221  			"1|a",
   222  		),
   223  	}, {
   224  		results: sqltypes.MakeTestStreamingResults(idColFields,
   225  			"3|c",
   226  			"---",
   227  			"2|b",
   228  		),
   229  	}, {
   230  		results: sqltypes.MakeTestStreamingResults(idColFields,
   231  			"6|f",
   232  			"4|d",
   233  		),
   234  	}, {
   235  		results: sqltypes.MakeTestStreamingResults(idColFields,
   236  			"8|h",
   237  			"---",
   238  			"4|d",
   239  		),
   240  	}}
   241  	orderBy := []OrderByParams{{
   242  		WeightStringCol: -1,
   243  		Col:             0,
   244  		Desc:            true,
   245  	}}
   246  
   247  	var results []*sqltypes.Result
   248  	err := testMergeSort(shardResults, orderBy, func(qr *sqltypes.Result) error {
   249  		results = append(results, qr)
   250  		return nil
   251  	})
   252  	require.NoError(t, err)
   253  
   254  	// Results are returned one row at a time.
   255  	wantResults := sqltypes.MakeTestStreamingResults(idColFields,
   256  		"8|h",
   257  		"---",
   258  		"7|g",
   259  		"---",
   260  		"6|f",
   261  		"---",
   262  		"4|d",
   263  		"---",
   264  		"4|d",
   265  		"---",
   266  		"3|c",
   267  		"---",
   268  		"2|b",
   269  		"---",
   270  		"1|a",
   271  	)
   272  	utils.MustMatch(t, wantResults, results)
   273  }
   274  
   275  func TestMergeSortEmptyResults(t *testing.T) {
   276  	idColFields := sqltypes.MakeTestFields("id|col", "int32|varchar")
   277  	shardResults := []*shardResult{{
   278  		results: sqltypes.MakeTestStreamingResults(idColFields,
   279  			"1|a",
   280  			"7|g",
   281  		),
   282  	}, {
   283  		results: sqltypes.MakeTestStreamingResults(idColFields),
   284  	}, {
   285  		results: sqltypes.MakeTestStreamingResults(idColFields,
   286  			"4|d",
   287  			"6|f",
   288  		),
   289  	}, {
   290  		results: sqltypes.MakeTestStreamingResults(idColFields),
   291  	}}
   292  	orderBy := []OrderByParams{{
   293  		WeightStringCol: -1,
   294  		Col:             0,
   295  	}}
   296  
   297  	var results []*sqltypes.Result
   298  	err := testMergeSort(shardResults, orderBy, func(qr *sqltypes.Result) error {
   299  		results = append(results, qr)
   300  		return nil
   301  	})
   302  	require.NoError(t, err)
   303  
   304  	// Results are returned one row at a time.
   305  	wantResults := sqltypes.MakeTestStreamingResults(idColFields,
   306  		"1|a",
   307  		"---",
   308  		"4|d",
   309  		"---",
   310  		"6|f",
   311  		"---",
   312  		"7|g",
   313  	)
   314  	utils.MustMatch(t, wantResults, results)
   315  }
   316  
   317  // TestMergeSortResultFailures tests failures at various
   318  // stages of result return.
   319  func TestMergeSortResultFailures(t *testing.T) {
   320  	orderBy := []OrderByParams{{
   321  		WeightStringCol: -1,
   322  		Col:             0,
   323  	}}
   324  
   325  	// Test early error.
   326  	shardResults := []*shardResult{{
   327  		sendErr: errors.New("early error"),
   328  	}}
   329  	err := testMergeSort(shardResults, orderBy, func(qr *sqltypes.Result) error { return nil })
   330  	want := "early error"
   331  	require.EqualError(t, err, want)
   332  
   333  	// Test fail after fields.
   334  	idFields := sqltypes.MakeTestFields("id", "int32")
   335  	shardResults = []*shardResult{{
   336  		results: sqltypes.MakeTestStreamingResults(idFields),
   337  		sendErr: errors.New("fail after fields"),
   338  	}}
   339  	err = testMergeSort(shardResults, orderBy, func(qr *sqltypes.Result) error { return nil })
   340  	want = "fail after fields"
   341  	require.EqualError(t, err, want)
   342  
   343  	// Test fail after first row.
   344  	shardResults = []*shardResult{{
   345  		results: sqltypes.MakeTestStreamingResults(idFields, "1"),
   346  		sendErr: errors.New("fail after first row"),
   347  	}}
   348  	err = testMergeSort(shardResults, orderBy, func(qr *sqltypes.Result) error { return nil })
   349  	want = "fail after first row"
   350  	require.EqualError(t, err, want)
   351  }
   352  
   353  func TestMergeSortDataFailures(t *testing.T) {
   354  	// The first row being bad fails in a different code path than
   355  	// the case of subsequent rows. So, test the two cases separately.
   356  	idColFields := sqltypes.MakeTestFields("id|col", "int32|varchar")
   357  	shardResults := []*shardResult{{
   358  		results: sqltypes.MakeTestStreamingResults(idColFields,
   359  			"1|a",
   360  		),
   361  	}, {
   362  		results: sqltypes.MakeTestStreamingResults(idColFields,
   363  			"2.1|b",
   364  		),
   365  	}}
   366  	orderBy := []OrderByParams{{
   367  		WeightStringCol: -1,
   368  		Col:             0,
   369  	}}
   370  
   371  	err := testMergeSort(shardResults, orderBy, func(qr *sqltypes.Result) error { return nil })
   372  	want := `strconv.ParseInt: parsing "2.1": invalid syntax`
   373  	require.EqualError(t, err, want)
   374  
   375  	// Create a new VCursor because the previous MergeSort will still
   376  	// have lingering goroutines that can cause data race.
   377  	shardResults = []*shardResult{{
   378  		results: sqltypes.MakeTestStreamingResults(idColFields,
   379  			"1|a",
   380  			"1.1|a",
   381  		),
   382  	}, {
   383  		results: sqltypes.MakeTestStreamingResults(idColFields,
   384  			"2|b",
   385  		),
   386  	}}
   387  	err = testMergeSort(shardResults, orderBy, func(qr *sqltypes.Result) error { return nil })
   388  	want = `strconv.ParseInt: parsing "1.1": invalid syntax`
   389  	require.EqualError(t, err, want)
   390  }
   391  
   392  func testMergeSort(shardResults []*shardResult, orderBy []OrderByParams, callback func(qr *sqltypes.Result) error) error {
   393  	prims := make([]StreamExecutor, 0, len(shardResults))
   394  	for _, sr := range shardResults {
   395  		prims = append(prims, sr)
   396  	}
   397  	ms := MergeSort{
   398  		Primitives: prims,
   399  		OrderBy:    orderBy,
   400  	}
   401  	return ms.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, callback)
   402  }
   403  
   404  type shardResult struct {
   405  	// shardRoute helps us avoid redefining the Primitive functions.
   406  	shardRoute
   407  
   408  	results []*sqltypes.Result
   409  	sendErr error
   410  }
   411  
   412  func (sr *shardResult) StreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
   413  	for _, r := range sr.results {
   414  		if err := callback(r); err != nil {
   415  			return err
   416  		}
   417  	}
   418  	return sr.sendErr
   419  }