vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/limit_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"testing"
    23  
    24  	"github.com/stretchr/testify/assert"
    25  
    26  	"vitess.io/vitess/go/mysql/collations"
    27  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    28  
    29  	"github.com/stretchr/testify/require"
    30  
    31  	"vitess.io/vitess/go/sqltypes"
    32  	querypb "vitess.io/vitess/go/vt/proto/query"
    33  )
    34  
    35  func TestLimitExecute(t *testing.T) {
    36  	bindVars := make(map[string]*querypb.BindVariable)
    37  	fields := sqltypes.MakeTestFields(
    38  		"col1|col2",
    39  		"int64|varchar",
    40  	)
    41  	inputResult := sqltypes.MakeTestResult(
    42  		fields,
    43  		"a|1",
    44  		"b|2",
    45  		"c|3",
    46  	)
    47  	fp := &fakePrimitive{
    48  		results: []*sqltypes.Result{inputResult},
    49  	}
    50  
    51  	l := &Limit{
    52  		Count: evalengine.NewLiteralInt(2),
    53  		Input: fp,
    54  	}
    55  
    56  	// Test with limit smaller than input.
    57  	result, err := l.TryExecute(context.Background(), &noopVCursor{}, bindVars, false)
    58  	require.NoError(t, err)
    59  	wantResult := sqltypes.MakeTestResult(
    60  		fields,
    61  		"a|1",
    62  		"b|2",
    63  	)
    64  	if !result.Equal(wantResult) {
    65  		t.Errorf("l.Execute:\n%v, want\n%v", result, wantResult)
    66  	}
    67  
    68  	// Test with limit equal to input.
    69  	wantResult = sqltypes.MakeTestResult(
    70  		fields,
    71  		"a|1",
    72  		"b|2",
    73  		"c|3",
    74  	)
    75  	inputResult = sqltypes.MakeTestResult(
    76  		fields,
    77  		"a|1",
    78  		"b|2",
    79  		"c|3",
    80  	)
    81  	fp = &fakePrimitive{
    82  		results: []*sqltypes.Result{inputResult},
    83  	}
    84  	l = &Limit{
    85  		Count: evalengine.NewLiteralInt(3),
    86  		Input: fp,
    87  	}
    88  
    89  	result, err = l.TryExecute(context.Background(), &noopVCursor{}, bindVars, false)
    90  	require.NoError(t, err)
    91  	if !result.Equal(inputResult) {
    92  		t.Errorf("l.Execute:\n%v, want\n%v", result, wantResult)
    93  	}
    94  
    95  	// Test with limit higher than input.
    96  	inputResult = sqltypes.MakeTestResult(
    97  		fields,
    98  		"a|1",
    99  		"b|2",
   100  		"c|3",
   101  	)
   102  	fp = &fakePrimitive{
   103  		results: []*sqltypes.Result{inputResult},
   104  	}
   105  	l = &Limit{
   106  		Count: evalengine.NewLiteralInt(4),
   107  		Input: fp,
   108  	}
   109  
   110  	result, err = l.TryExecute(context.Background(), &noopVCursor{}, bindVars, false)
   111  	require.NoError(t, err)
   112  	if !result.Equal(wantResult) {
   113  		t.Errorf("l.Execute:\n%v, want\n%v", result, wantResult)
   114  	}
   115  
   116  	// Test with bind vars.
   117  	wantResult = sqltypes.MakeTestResult(
   118  		fields,
   119  		"a|1",
   120  		"b|2",
   121  	)
   122  	inputResult = sqltypes.MakeTestResult(
   123  		fields,
   124  		"a|1",
   125  		"b|2",
   126  		"c|3",
   127  	)
   128  	fp = &fakePrimitive{
   129  		results: []*sqltypes.Result{inputResult},
   130  	}
   131  	l = &Limit{
   132  		Count: evalengine.NewBindVar("l", collations.TypedCollation{}),
   133  		Input: fp,
   134  	}
   135  
   136  	result, err = l.TryExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{"l": sqltypes.Int64BindVariable(2)}, false)
   137  	require.NoError(t, err)
   138  	if !result.Equal(wantResult) {
   139  		t.Errorf("l.Execute:\n%v, want\n%v", result, wantResult)
   140  	}
   141  }
   142  
   143  func TestLimitOffsetExecute(t *testing.T) {
   144  	bindVars := make(map[string]*querypb.BindVariable)
   145  	fields := sqltypes.MakeTestFields(
   146  		"col1|col2",
   147  		"int64|varchar",
   148  	)
   149  	inputResult := sqltypes.MakeTestResult(
   150  		fields,
   151  		"a|1",
   152  		"b|2",
   153  		"c|3",
   154  		"c|4",
   155  		"c|5",
   156  		"c|6",
   157  	)
   158  	fp := &fakePrimitive{
   159  		results: []*sqltypes.Result{inputResult},
   160  	}
   161  
   162  	l := &Limit{
   163  		Count:  evalengine.NewLiteralInt(2),
   164  		Offset: evalengine.NewLiteralInt(0),
   165  		Input:  fp,
   166  	}
   167  
   168  	// Test with offset 0
   169  	result, err := l.TryExecute(context.Background(), &noopVCursor{}, bindVars, false)
   170  	require.NoError(t, err)
   171  	wantResult := sqltypes.MakeTestResult(
   172  		fields,
   173  		"a|1",
   174  		"b|2",
   175  	)
   176  	if !result.Equal(wantResult) {
   177  		t.Errorf("l.Execute:\n%v, want\n%v", result, wantResult)
   178  	}
   179  
   180  	// Test with offset set
   181  
   182  	inputResult = sqltypes.MakeTestResult(
   183  		fields,
   184  		"a|1",
   185  		"b|2",
   186  		"c|3",
   187  		"c|4",
   188  		"c|5",
   189  		"c|6",
   190  	)
   191  	fp = &fakePrimitive{
   192  		results: []*sqltypes.Result{inputResult},
   193  	}
   194  
   195  	l = &Limit{
   196  		Count:  evalengine.NewLiteralInt(2),
   197  		Offset: evalengine.NewLiteralInt(1),
   198  		Input:  fp,
   199  	}
   200  	wantResult = sqltypes.MakeTestResult(
   201  		fields,
   202  		"b|2",
   203  		"c|3",
   204  	)
   205  	result, err = l.TryExecute(context.Background(), &noopVCursor{}, bindVars, false)
   206  	require.NoError(t, err)
   207  	if !result.Equal(wantResult) {
   208  		t.Errorf("l.Execute:\n got %v, want\n%v", result, wantResult)
   209  	}
   210  
   211  	// Works on boundary condition (elements == limit + offset)
   212  	inputResult = sqltypes.MakeTestResult(
   213  		fields,
   214  		"a|1",
   215  		"b|2",
   216  		"c|3",
   217  		"c|4",
   218  		"c|5",
   219  		"c|6",
   220  	)
   221  	fp = &fakePrimitive{
   222  		results: []*sqltypes.Result{inputResult},
   223  	}
   224  
   225  	l = &Limit{
   226  		Count:  evalengine.NewLiteralInt(2),
   227  		Offset: evalengine.NewLiteralInt(4),
   228  		Input:  fp,
   229  	}
   230  	wantResult = sqltypes.MakeTestResult(
   231  		fields,
   232  		"c|5",
   233  		"c|6",
   234  	)
   235  	result, err = l.TryExecute(context.Background(), &noopVCursor{}, bindVars, false)
   236  	require.NoError(t, err)
   237  	if !result.Equal(wantResult) {
   238  		t.Errorf("l.Execute:\n got %v, want\n%v", result, wantResult)
   239  	}
   240  
   241  	inputResult = sqltypes.MakeTestResult(
   242  		fields,
   243  		"a|1",
   244  		"b|2",
   245  		"c|3",
   246  		"c|4",
   247  		"c|5",
   248  		"c|6",
   249  	)
   250  	fp = &fakePrimitive{
   251  		results: []*sqltypes.Result{inputResult},
   252  	}
   253  
   254  	l = &Limit{
   255  		Count:  evalengine.NewLiteralInt(4),
   256  		Offset: evalengine.NewLiteralInt(2),
   257  		Input:  fp,
   258  	}
   259  	wantResult = sqltypes.MakeTestResult(
   260  		fields,
   261  		"c|3",
   262  		"c|4",
   263  		"c|5",
   264  		"c|6",
   265  	)
   266  	result, err = l.TryExecute(context.Background(), &noopVCursor{}, bindVars, false)
   267  	require.NoError(t, err)
   268  	if !result.Equal(wantResult) {
   269  		t.Errorf("l.Execute:\n got %v, want\n%v", result, wantResult)
   270  	}
   271  
   272  	// test when limit is beyond the number of available elements
   273  	inputResult = sqltypes.MakeTestResult(
   274  		fields,
   275  		"a|1",
   276  		"b|2",
   277  		"c|3",
   278  		"c|4",
   279  		"c|5",
   280  		"c|6",
   281  	)
   282  	fp = &fakePrimitive{
   283  		results: []*sqltypes.Result{inputResult},
   284  	}
   285  
   286  	l = &Limit{
   287  		Count:  evalengine.NewLiteralInt(2),
   288  		Offset: evalengine.NewLiteralInt(5),
   289  		Input:  fp,
   290  	}
   291  	wantResult = sqltypes.MakeTestResult(
   292  		fields,
   293  		"c|6",
   294  	)
   295  	result, err = l.TryExecute(context.Background(), &noopVCursor{}, bindVars, false)
   296  	require.NoError(t, err)
   297  	if !result.Equal(wantResult) {
   298  		t.Errorf("l.Execute:\n got %v, want\n%v", result, wantResult)
   299  	}
   300  
   301  	// Works when offset is beyond the response
   302  	inputResult = sqltypes.MakeTestResult(
   303  		fields,
   304  		"a|1",
   305  		"b|2",
   306  		"c|3",
   307  		"c|4",
   308  		"c|5",
   309  		"c|6",
   310  	)
   311  	fp = &fakePrimitive{
   312  		results: []*sqltypes.Result{inputResult},
   313  	}
   314  
   315  	l = &Limit{
   316  		Count:  evalengine.NewLiteralInt(2),
   317  		Offset: evalengine.NewLiteralInt(7),
   318  		Input:  fp,
   319  	}
   320  	wantResult = sqltypes.MakeTestResult(
   321  		fields,
   322  	)
   323  	result, err = l.TryExecute(context.Background(), &noopVCursor{}, bindVars, false)
   324  	require.NoError(t, err)
   325  	if !result.Equal(wantResult) {
   326  		t.Errorf("l.Execute:\n got %v, want\n%v", result, wantResult)
   327  	}
   328  
   329  	// works with bindvars
   330  	inputResult = sqltypes.MakeTestResult(
   331  		fields,
   332  		"x|1",
   333  		"z|2",
   334  	)
   335  	wantResult = sqltypes.MakeTestResult(
   336  		fields,
   337  		"z|2",
   338  	)
   339  
   340  	fp = &fakePrimitive{
   341  		results: []*sqltypes.Result{inputResult},
   342  	}
   343  
   344  	l = &Limit{
   345  		Count:  evalengine.NewBindVar("l", collations.TypedCollation{}),
   346  		Offset: evalengine.NewBindVar("o", collations.TypedCollation{}),
   347  		Input:  fp,
   348  	}
   349  	result, err = l.TryExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{"l": sqltypes.Int64BindVariable(1), "o": sqltypes.Int64BindVariable(1)}, false)
   350  	require.NoError(t, err)
   351  	if !result.Equal(wantResult) {
   352  		t.Errorf("l.Execute:\n got %v, want\n%v", result, wantResult)
   353  	}
   354  }
   355  
   356  func TestLimitStreamExecute(t *testing.T) {
   357  	bindVars := make(map[string]*querypb.BindVariable)
   358  	fields := sqltypes.MakeTestFields(
   359  		"col1|col2",
   360  		"int64|varchar",
   361  	)
   362  	inputResult := sqltypes.MakeTestResult(
   363  		fields,
   364  		"a|1",
   365  		"b|2",
   366  		"c|3",
   367  	)
   368  	fp := &fakePrimitive{
   369  		results: []*sqltypes.Result{inputResult},
   370  	}
   371  
   372  	l := &Limit{
   373  		Count: evalengine.NewLiteralInt(2),
   374  		Input: fp,
   375  	}
   376  
   377  	// Test with limit smaller than input.
   378  	var results []*sqltypes.Result
   379  	err := l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error {
   380  		results = append(results, qr)
   381  		return nil
   382  	})
   383  	require.NoError(t, err)
   384  	wantResults := sqltypes.MakeTestStreamingResults(
   385  		fields,
   386  		"a|1",
   387  		"b|2",
   388  	)
   389  	require.Len(t, results, len(wantResults))
   390  	for i, result := range results {
   391  		if !result.Equal(wantResults[i]) {
   392  			t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults))
   393  		}
   394  	}
   395  
   396  	// Test with bind vars.
   397  	fp.rewind()
   398  	l.Count = evalengine.NewBindVar("l", collations.TypedCollation{})
   399  	results = nil
   400  	err = l.TryStreamExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{"l": sqltypes.Int64BindVariable(2)}, true, func(qr *sqltypes.Result) error {
   401  		results = append(results, qr)
   402  		return nil
   403  	})
   404  	require.NoError(t, err)
   405  	require.Len(t, results, len(wantResults))
   406  	for i, result := range results {
   407  		if !result.Equal(wantResults[i]) {
   408  			t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults))
   409  		}
   410  	}
   411  
   412  	// Test with limit equal to input
   413  	fp.rewind()
   414  	l.Count = evalengine.NewLiteralInt(3)
   415  	results = nil
   416  	err = l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error {
   417  		results = append(results, qr)
   418  		return nil
   419  	})
   420  	require.NoError(t, err)
   421  	wantResults = sqltypes.MakeTestStreamingResults(
   422  		fields,
   423  		"a|1",
   424  		"b|2",
   425  		"---",
   426  		"c|3",
   427  	)
   428  	require.Len(t, results, len(wantResults))
   429  	for i, result := range results {
   430  		if !result.Equal(wantResults[i]) {
   431  			t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults))
   432  		}
   433  	}
   434  
   435  	// Test with limit higher than input.
   436  	fp.rewind()
   437  	l.Count = evalengine.NewLiteralInt(4)
   438  	results = nil
   439  	err = l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error {
   440  		results = append(results, qr)
   441  		return nil
   442  	})
   443  	require.NoError(t, err)
   444  	// wantResults is same as before.
   445  	require.Len(t, results, len(wantResults))
   446  	for i, result := range results {
   447  		if !result.Equal(wantResults[i]) {
   448  			t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults))
   449  		}
   450  	}
   451  }
   452  
   453  func TestOffsetStreamExecute(t *testing.T) {
   454  	bindVars := make(map[string]*querypb.BindVariable)
   455  	fields := sqltypes.MakeTestFields(
   456  		"col1|col2",
   457  		"int64|varchar",
   458  	)
   459  	inputResult := sqltypes.MakeTestResult(
   460  		fields,
   461  		"a|1",
   462  		"b|2",
   463  		"c|3",
   464  		"d|4",
   465  		"e|5",
   466  		"f|6",
   467  	)
   468  	fp := &fakePrimitive{
   469  		results: []*sqltypes.Result{inputResult},
   470  	}
   471  
   472  	l := &Limit{
   473  		Offset: evalengine.NewLiteralInt(2),
   474  		Count:  evalengine.NewLiteralInt(3),
   475  		Input:  fp,
   476  	}
   477  
   478  	var results []*sqltypes.Result
   479  	err := l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error {
   480  		results = append(results, qr)
   481  		return nil
   482  	})
   483  	require.NoError(t, err)
   484  	wantResults := sqltypes.MakeTestStreamingResults(
   485  		fields,
   486  		"c|3",
   487  		"d|4",
   488  		"---",
   489  		"e|5",
   490  	)
   491  	require.Len(t, results, len(wantResults))
   492  	for i, result := range results {
   493  		if !result.Equal(wantResults[i]) {
   494  			t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults))
   495  		}
   496  	}
   497  }
   498  
   499  func TestLimitGetFields(t *testing.T) {
   500  	result := sqltypes.MakeTestResult(
   501  		sqltypes.MakeTestFields(
   502  			"col1|col2",
   503  			"int64|varchar",
   504  		),
   505  	)
   506  	fp := &fakePrimitive{results: []*sqltypes.Result{result}}
   507  
   508  	l := &Limit{Input: fp}
   509  
   510  	got, err := l.GetFields(context.Background(), nil, nil)
   511  	require.NoError(t, err)
   512  	if !got.Equal(result) {
   513  		t.Errorf("l.GetFields:\n%v, want\n%v", got, result)
   514  	}
   515  }
   516  
   517  func TestLimitInputFail(t *testing.T) {
   518  	bindVars := make(map[string]*querypb.BindVariable)
   519  	fp := &fakePrimitive{sendErr: errors.New("input fail")}
   520  
   521  	l := &Limit{Count: evalengine.NewLiteralInt(1), Input: fp}
   522  
   523  	want := "input fail"
   524  	if _, err := l.TryExecute(context.Background(), &noopVCursor{}, bindVars, false); err == nil || err.Error() != want {
   525  		t.Errorf("l.Execute(): %v, want %s", err, want)
   526  	}
   527  
   528  	fp.rewind()
   529  	err := l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, false, func(_ *sqltypes.Result) error { return nil })
   530  	if err == nil || err.Error() != want {
   531  		t.Errorf("l.StreamExecute(): %v, want %s", err, want)
   532  	}
   533  
   534  	fp.rewind()
   535  	if _, err := l.GetFields(context.Background(), nil, nil); err == nil || err.Error() != want {
   536  		t.Errorf("l.GetFields(): %v, want %s", err, want)
   537  	}
   538  }
   539  
   540  func TestLimitInvalidCount(t *testing.T) {
   541  	l := &Limit{
   542  		Count: evalengine.NewBindVar("l", collations.TypedCollation{}),
   543  	}
   544  	_, _, err := l.getCountAndOffset(&noopVCursor{}, nil)
   545  	assert.EqualError(t, err, "query arguments missing for l")
   546  
   547  	l.Count = evalengine.NewLiteralFloat(1.2)
   548  	_, _, err = l.getCountAndOffset(&noopVCursor{}, nil)
   549  	assert.EqualError(t, err, "Cannot convert value to desired type")
   550  
   551  	l.Count = evalengine.NewLiteralUint(18446744073709551615)
   552  	_, _, err = l.getCountAndOffset(&noopVCursor{}, nil)
   553  	assert.EqualError(t, err, "requested limit is out of range: 18446744073709551615")
   554  
   555  	// When going through the API, it should return the same error.
   556  	_, err = l.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   557  	assert.EqualError(t, err, "requested limit is out of range: 18446744073709551615")
   558  
   559  	err = l.TryStreamExecute(context.Background(), &noopVCursor{}, nil, false, func(_ *sqltypes.Result) error { return nil })
   560  	assert.EqualError(t, err, "requested limit is out of range: 18446744073709551615")
   561  }