github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/in_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression_test
    16  
    17  import (
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/dolthub/vitess/go/sqltypes"
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  	"gopkg.in/src-d/go-errors.v1"
    25  
    26  	"github.com/dolthub/go-mysql-server/sql"
    27  	"github.com/dolthub/go-mysql-server/sql/expression"
    28  	"github.com/dolthub/go-mysql-server/sql/expression/function"
    29  	"github.com/dolthub/go-mysql-server/sql/types"
    30  )
    31  
    32  var testEnumType = types.MustCreateEnumType([]string{"", "one", "two"}, sql.Collation_Default)
    33  
    34  var testSetType = types.MustCreateSetType([]string{"", "one", "two"}, sql.Collation_Default)
    35  
    36  func TestRoundTripNames(t *testing.T) {
    37  	assert.Equal(t, "(foo IN (foo, 2))", expression.NewInTuple(expression.NewGetField(0, types.Int64, "foo", false),
    38  		expression.NewTuple(
    39  			expression.NewGetField(0, types.Int64, "foo", false),
    40  			expression.NewLiteral(int64(2), types.Int64),
    41  		)).String())
    42  	hit, err := expression.NewHashInTuple(nil, expression.NewGetField(0, types.Int64, "foo", false),
    43  		expression.NewTuple(
    44  			expression.NewLiteral(int64(2), types.Int64),
    45  		))
    46  	assert.NoError(t, err)
    47  	assert.Equal(t, "(foo HASH IN (2))", hit.String())
    48  }
    49  
    50  func TestInTuple(t *testing.T) {
    51  	testCases := []struct {
    52  		name   string
    53  		left   sql.Expression
    54  		right  sql.Expression
    55  		row    sql.Row
    56  		result interface{}
    57  		err    *errors.Kind
    58  	}{
    59  		{
    60  			"left is nil",
    61  			expression.NewLiteral(nil, types.Null),
    62  			expression.NewTuple(
    63  				expression.NewLiteral(int64(1), types.Int64),
    64  				expression.NewLiteral(int64(2), types.Int64),
    65  			),
    66  			nil,
    67  			nil,
    68  			nil,
    69  		},
    70  		{
    71  			"left and right don't have the same cols",
    72  			expression.NewLiteral(1, types.Int64),
    73  			expression.NewTuple(
    74  				expression.NewTuple(
    75  					expression.NewLiteral(int64(1), types.Int64),
    76  					expression.NewLiteral(int64(1), types.Int64),
    77  				),
    78  				expression.NewLiteral(int64(2), types.Int64),
    79  			),
    80  			nil,
    81  			nil,
    82  			sql.ErrInvalidOperandColumns,
    83  		},
    84  		{
    85  			"right is an unsupported operand",
    86  			expression.NewLiteral(1, types.Int64),
    87  			expression.NewLiteral(int64(2), types.Int64),
    88  			nil,
    89  			nil,
    90  			expression.ErrUnsupportedInOperand,
    91  		},
    92  		{
    93  			"left is in right",
    94  			expression.NewGetField(0, types.Int64, "foo", false),
    95  			expression.NewTuple(
    96  				expression.NewGetField(0, types.Int64, "foo", false),
    97  				expression.NewLiteral(int64(2), types.Int64),
    98  			),
    99  			sql.NewRow(int64(1)),
   100  			true,
   101  			nil,
   102  		},
   103  		{
   104  			"left is not in right",
   105  			expression.NewGetField(0, types.Int64, "foo", false),
   106  			expression.NewTuple(
   107  				expression.NewGetField(1, types.Int64, "bar", false),
   108  				expression.NewLiteral(int64(2), types.Int64),
   109  			),
   110  			sql.NewRow(int64(1), int64(3)),
   111  			false,
   112  			nil,
   113  		},
   114  		{
   115  			name: "right values contain a different, coercible type",
   116  			left: expression.NewLiteral(1, types.Uint64),
   117  			right: expression.NewTuple(
   118  				expression.NewLiteral("hi", types.TinyText),
   119  				expression.NewLiteral("bye", types.TinyText),
   120  			),
   121  			row:    nil,
   122  			result: false,
   123  		},
   124  		{
   125  			name: "right values contain a different, coercible type, and left value is zero value",
   126  			left: expression.NewLiteral(0, types.Uint64),
   127  			right: expression.NewTuple(
   128  				expression.NewLiteral("hi", types.TinyText),
   129  				expression.NewLiteral("bye", types.TinyText),
   130  			),
   131  			row:    nil,
   132  			result: true,
   133  		},
   134  		{
   135  			name: "enum on left side; invalid values on right",
   136  			left: expression.NewLiteral("one", testEnumType),
   137  			right: expression.NewTuple(
   138  				expression.NewLiteral("hi", types.TinyText),
   139  				expression.NewLiteral("bye", types.TinyText),
   140  			),
   141  			row:    nil,
   142  			result: false,
   143  		},
   144  		{
   145  			name: "enum on left side; valid enum values on right",
   146  			left: expression.NewLiteral("one", testEnumType),
   147  			right: expression.NewTuple(
   148  				expression.NewLiteral("", types.TinyText),
   149  				expression.NewLiteral("one", types.TinyText),
   150  			),
   151  			row:    nil,
   152  			result: true,
   153  		},
   154  		{
   155  			name: "set on left side; invalid set values on right",
   156  			left: expression.NewLiteral("one", testSetType),
   157  			right: expression.NewTuple(
   158  				expression.NewLiteral("hi", types.TinyText),
   159  				expression.NewLiteral("bye", types.TinyText),
   160  			),
   161  			row:    nil,
   162  			result: false,
   163  		},
   164  		{
   165  			name: "set on left side; valid set values on right",
   166  			left: expression.NewLiteral("one", testSetType),
   167  			right: expression.NewTuple(
   168  				expression.NewLiteral("", types.TinyText),
   169  				expression.NewLiteral("one", types.TinyText),
   170  			),
   171  			row:    nil,
   172  			result: true,
   173  		},
   174  		{
   175  			name: "date on right side; non-dates on left",
   176  			left: expression.NewLiteral(time.Now(), types.DatetimeMaxPrecision),
   177  			right: expression.NewTuple(
   178  				expression.NewLiteral("hi", types.TinyText),
   179  				expression.NewLiteral("bye", types.TinyText),
   180  			),
   181  			err:    types.ErrConvertingToTime,
   182  			row:    nil,
   183  			result: false,
   184  		}}
   185  
   186  	for _, tt := range testCases {
   187  		t.Run(tt.name, func(t *testing.T) {
   188  			require := require.New(t)
   189  
   190  			result, err := expression.NewInTuple(tt.left, tt.right).
   191  				Eval(sql.NewEmptyContext(), tt.row)
   192  			if tt.err != nil {
   193  				require.Error(err)
   194  				require.True(tt.err.Is(err))
   195  			} else {
   196  				require.NoError(err)
   197  				require.Equal(tt.result, result)
   198  			}
   199  		})
   200  	}
   201  }
   202  
   203  func TestNotInTuple(t *testing.T) {
   204  	testCases := []struct {
   205  		name   string
   206  		left   sql.Expression
   207  		right  sql.Expression
   208  		row    sql.Row
   209  		result interface{}
   210  		err    *errors.Kind
   211  	}{
   212  		{
   213  			"left is nil",
   214  			expression.NewLiteral(nil, types.Null),
   215  			expression.NewTuple(
   216  				expression.NewLiteral(int64(1), types.Int64),
   217  				expression.NewLiteral(int64(2), types.Int64),
   218  			),
   219  			nil,
   220  			nil,
   221  			nil,
   222  		},
   223  		{
   224  			"left and right don't have the same cols",
   225  			expression.NewLiteral(1, types.Int64),
   226  			expression.NewTuple(
   227  				expression.NewTuple(
   228  					expression.NewLiteral(int64(1), types.Int64),
   229  					expression.NewLiteral(int64(1), types.Int64),
   230  				),
   231  				expression.NewLiteral(int64(2), types.Int64),
   232  			),
   233  			nil,
   234  			nil,
   235  			sql.ErrInvalidOperandColumns,
   236  		},
   237  		{
   238  			"right is an unsupported operand",
   239  			expression.NewLiteral(1, types.Int64),
   240  			expression.NewLiteral(int64(2), types.Int64),
   241  			nil,
   242  			nil,
   243  			expression.ErrUnsupportedInOperand,
   244  		},
   245  		{
   246  			"left is in right",
   247  			expression.NewGetField(0, types.Int64, "foo", false),
   248  			expression.NewTuple(
   249  				expression.NewGetField(0, types.Int64, "foo", false),
   250  				expression.NewLiteral(int64(2), types.Int64),
   251  			),
   252  			sql.NewRow(int64(1)),
   253  			false,
   254  			nil,
   255  		},
   256  		{
   257  			"left is not in right",
   258  			expression.NewGetField(0, types.Int64, "foo", false),
   259  			expression.NewTuple(
   260  				expression.NewGetField(1, types.Int64, "bar", false),
   261  				expression.NewLiteral(int64(2), types.Int64),
   262  			),
   263  			sql.NewRow(int64(1), int64(3)),
   264  			true,
   265  			nil,
   266  		},
   267  	}
   268  
   269  	for _, tt := range testCases {
   270  		t.Run(tt.name, func(t *testing.T) {
   271  			require := require.New(t)
   272  
   273  			result, err := expression.NewNotInTuple(tt.left, tt.right).
   274  				Eval(sql.NewEmptyContext(), tt.row)
   275  			if tt.err != nil {
   276  				require.Error(err)
   277  				require.True(tt.err.Is(err))
   278  			} else {
   279  				require.NoError(err)
   280  				require.Equal(tt.result, result)
   281  			}
   282  		})
   283  	}
   284  }
   285  
   286  func TestHashInTuple(t *testing.T) {
   287  	testCases := []struct {
   288  		name      string
   289  		left      sql.Expression
   290  		right     sql.Expression
   291  		row       sql.Row
   292  		result    interface{}
   293  		staticErr *errors.Kind
   294  		evalErr   *errors.Kind
   295  	}{
   296  		{
   297  			"left is nil",
   298  			expression.NewLiteral(nil, types.Null),
   299  			expression.NewTuple(
   300  				expression.NewLiteral(int64(1), types.Int64),
   301  				expression.NewLiteral(int64(2), types.Int64),
   302  			),
   303  			nil,
   304  			nil,
   305  			nil,
   306  			nil,
   307  		},
   308  		{
   309  			"left and right don't have the same number of cols; right has tuple",
   310  			expression.NewLiteral(1, types.Int64),
   311  			expression.NewTuple(
   312  				expression.NewTuple(
   313  					expression.NewLiteral(int64(1), types.Int64),
   314  					expression.NewLiteral(int64(1), types.Int64),
   315  				),
   316  				expression.NewLiteral(int64(2), types.Int64),
   317  			),
   318  			nil,
   319  			false,
   320  			sql.ErrInvalidOperandColumns,
   321  			nil,
   322  		},
   323  		{
   324  			"left and right don't have the same number of cols; left has tuple",
   325  			expression.NewTuple(
   326  				expression.NewLiteral(1, types.Int64),
   327  				expression.NewLiteral(0, types.Int64),
   328  			),
   329  			expression.NewTuple(
   330  				expression.NewTuple(
   331  					expression.NewLiteral(int64(1), types.Int64),
   332  					expression.NewLiteral(int64(1), types.Int64),
   333  				),
   334  				expression.NewLiteral(int64(2), types.Int64),
   335  			),
   336  			nil,
   337  			false,
   338  			sql.ErrInvalidOperandColumns,
   339  			nil,
   340  		},
   341  		{
   342  			"right is an unsupported operand",
   343  			expression.NewLiteral(1, types.Int64),
   344  			expression.NewLiteral(int64(2), types.Int64),
   345  			nil,
   346  			nil,
   347  			expression.ErrUnsupportedInOperand,
   348  			nil,
   349  		},
   350  		{
   351  			"left is in right",
   352  			expression.NewGetField(0, types.Int64, "foo", false),
   353  			expression.NewTuple(
   354  				expression.NewLiteral(int64(2), types.Int64),
   355  				expression.NewLiteral(int64(1), types.Int64),
   356  				expression.NewLiteral(int64(0), types.Int64),
   357  			),
   358  			sql.NewRow(int64(1)),
   359  			true,
   360  			nil,
   361  			nil,
   362  		},
   363  		{
   364  			"left is not in right",
   365  			expression.NewGetField(0, types.Int64, "foo", false),
   366  			expression.NewTuple(
   367  				expression.NewLiteral(int64(0), types.Int64),
   368  				expression.NewLiteral(int64(2), types.Int64),
   369  			),
   370  			sql.NewRow(int64(1), int64(3)),
   371  			false,
   372  			nil,
   373  			nil,
   374  		},
   375  		{
   376  			"left tuple is in right",
   377  			expression.NewTuple(
   378  				expression.NewLiteral(int64(2), types.Int64),
   379  				expression.NewLiteral(int64(1), types.Int64),
   380  			),
   381  			expression.NewTuple(
   382  				expression.NewTuple(
   383  					expression.NewLiteral(int64(2), types.Int64),
   384  					expression.NewLiteral(int64(1), types.Int64),
   385  				),
   386  				expression.NewTuple(
   387  					expression.NewLiteral(int64(1), types.Int64),
   388  					expression.NewLiteral(int64(0), types.Int64),
   389  				),
   390  			),
   391  			nil,
   392  			true,
   393  			nil,
   394  			nil,
   395  		},
   396  		{
   397  			"heterogeneous left tuple is in right",
   398  			expression.NewTuple(
   399  				expression.NewLiteral(int64(2), types.Int64),
   400  				expression.NewLiteral("a", types.MustCreateStringWithDefaults(sqltypes.VarChar, 20)),
   401  			),
   402  			expression.NewTuple(
   403  				expression.NewTuple(
   404  					expression.NewLiteral(int64(1), types.Int64),
   405  					expression.NewLiteral("b", types.MustCreateStringWithDefaults(sqltypes.VarChar, 20)),
   406  				),
   407  				expression.NewTuple(
   408  					expression.NewLiteral(int64(2), types.Int64),
   409  					expression.NewLiteral("a", types.MustCreateStringWithDefaults(sqltypes.VarChar, 20)),
   410  				),
   411  			),
   412  			nil,
   413  			true,
   414  			nil,
   415  			nil,
   416  		},
   417  		{
   418  			"left get field tuple is in right",
   419  			expression.NewTuple(
   420  				expression.NewGetField(0, types.Int64, "foo", false),
   421  				expression.NewGetField(1, types.Int64, "foo", false),
   422  			),
   423  			expression.NewTuple(
   424  				expression.NewTuple(
   425  					expression.NewLiteral(int64(2), types.Int64),
   426  					expression.NewLiteral(int64(1), types.Int64),
   427  				),
   428  				expression.NewTuple(
   429  					expression.NewLiteral(int64(1), types.Int64),
   430  					expression.NewLiteral(int64(0), types.Int64),
   431  				),
   432  			),
   433  			sql.NewRow(int64(1), int64(0)),
   434  			true,
   435  			nil,
   436  			nil,
   437  		},
   438  		{
   439  			"left nested tuple is in right",
   440  			expression.NewTuple(
   441  				expression.NewTuple(
   442  					expression.NewLiteral(int64(2), types.Int64),
   443  					expression.NewLiteral(int64(1), types.Int64),
   444  				),
   445  				expression.NewLiteral(int64(1), types.Int64),
   446  			),
   447  			expression.NewTuple(
   448  				expression.NewTuple(
   449  					expression.NewTuple(
   450  						expression.NewLiteral(int64(2), types.Int64),
   451  						expression.NewLiteral(int64(1), types.Int64),
   452  					),
   453  					expression.NewLiteral(int64(1), types.Int64),
   454  				),
   455  				expression.NewTuple(
   456  					expression.NewTuple(
   457  						expression.NewLiteral(int64(1), types.Int64),
   458  						expression.NewLiteral(int64(2), types.Int64),
   459  					),
   460  					expression.NewLiteral(int64(0), types.Int64),
   461  				),
   462  			),
   463  			nil,
   464  			true,
   465  			nil,
   466  			nil,
   467  		},
   468  		{
   469  			name: "left has a function",
   470  			left: expression.NewTuple(
   471  				function.NewLower(
   472  					expression.NewLiteral("hi", types.TinyText),
   473  				),
   474  			),
   475  			right: expression.NewTuple(
   476  				expression.NewLiteral("hi", types.TinyText),
   477  			),
   478  			result: true,
   479  		},
   480  		{
   481  			name: "right values contain a different, coercible type",
   482  			left: expression.NewLiteral(1, types.Uint64),
   483  			right: expression.NewTuple(
   484  				expression.NewLiteral("hi", types.TinyText),
   485  				expression.NewLiteral("bye", types.TinyText),
   486  			),
   487  			row:    nil,
   488  			result: false,
   489  		},
   490  		{
   491  			name: "right values contain zero floats that are equal to the left value",
   492  			left: expression.NewLiteral(0, types.Uint64),
   493  			right: expression.NewTuple(
   494  				expression.NewLiteral(0.0, types.Float64),
   495  				expression.NewLiteral(1.23, types.Float64),
   496  			),
   497  			row:    nil,
   498  			result: true,
   499  		},
   500  		{
   501  			name: "right values contain floats that are equal to the left value",
   502  			left: expression.NewLiteral(1, types.Uint64),
   503  			right: expression.NewTuple(
   504  				expression.NewLiteral(1.0, types.Float64),
   505  				expression.NewLiteral(1.23, types.Float64),
   506  			),
   507  			row:    nil,
   508  			result: true,
   509  		},
   510  		{
   511  			name: "right values contain decimals that are equal to the left value",
   512  			left: expression.NewLiteral(1, types.Uint64),
   513  			right: expression.NewTuple(
   514  				expression.NewLiteral(1.0, types.MustCreateDecimalType(10, 5)),
   515  				expression.NewLiteral(1.23, types.MustCreateDecimalType(10, 5)),
   516  			),
   517  			row:    nil,
   518  			result: true,
   519  		},
   520  		{
   521  			name: "right values contain a different, coercible type, and left value is zero value",
   522  			left: expression.NewLiteral(0, types.Uint64),
   523  			right: expression.NewTuple(
   524  				expression.NewLiteral("hi", types.TinyText),
   525  				expression.NewLiteral("bye", types.TinyText),
   526  			),
   527  			row:    nil,
   528  			result: true,
   529  		},
   530  		{
   531  			name: "enum on left side; invalid values on right",
   532  			left: expression.NewLiteral("one", testEnumType),
   533  			right: expression.NewTuple(
   534  				expression.NewLiteral("hi", types.TinyText),
   535  				expression.NewLiteral("bye", types.TinyText),
   536  			),
   537  			row:    nil,
   538  			result: false,
   539  		},
   540  		{
   541  			name: "enum on left side; valid enum values on right",
   542  			left: expression.NewLiteral("one", testEnumType),
   543  			right: expression.NewTuple(
   544  				expression.NewLiteral("", types.TinyText),
   545  				expression.NewLiteral("one", types.TinyText),
   546  			),
   547  			row:    nil,
   548  			result: true,
   549  		},
   550  		{
   551  			name: "set on left side; invalid set values on right",
   552  			left: expression.NewLiteral("one", testSetType),
   553  			right: expression.NewTuple(
   554  				expression.NewLiteral("hi", types.TinyText),
   555  				expression.NewLiteral("bye", types.TinyText),
   556  			),
   557  			row:    nil,
   558  			result: false,
   559  		},
   560  		{
   561  			name: "set on left side; valid set values on right",
   562  			left: expression.NewLiteral("one", testSetType),
   563  			right: expression.NewTuple(
   564  				expression.NewLiteral("", types.TinyText),
   565  				expression.NewLiteral("one", types.TinyText),
   566  			),
   567  			row:    nil,
   568  			result: true,
   569  		},
   570  		{
   571  			name: "date on right side; non-dates on left",
   572  			left: expression.NewLiteral(time.Now(), types.DatetimeMaxPrecision),
   573  			right: expression.NewTuple(
   574  				expression.NewLiteral("hi", types.TinyText),
   575  				expression.NewLiteral("bye", types.TinyText),
   576  			),
   577  			staticErr: types.ErrConvertingToTime,
   578  			row:       nil,
   579  			result:    false,
   580  		},
   581  		{
   582  			name: "left has a convert (type cast)",
   583  			left: expression.NewConvert(
   584  				expression.NewGetField(0, types.Int64, "foo", false),
   585  				"char",
   586  			),
   587  			right: expression.NewTuple(
   588  				expression.NewLiteral("1", types.TinyText),
   589  			),
   590  			row: sql.NewRow(int64(1), int64(0)),
   591  
   592  			result: true,
   593  		},
   594  		{
   595  			name: "left has a comparer",
   596  			left: expression.NewGreaterThan(
   597  				expression.NewGetField(0, types.Int64, "foo", false),
   598  				expression.NewLiteral(1, types.Int64),
   599  			),
   600  			right: expression.NewTuple(
   601  				expression.NewLiteral(true, types.Boolean),
   602  			),
   603  			row:    sql.NewRow(int64(2), int64(0)),
   604  			result: true,
   605  		},
   606  		{
   607  			name: "left has an is null",
   608  			left: expression.NewIsNull(
   609  				expression.NewLiteral(nil, types.Null),
   610  			),
   611  			right: expression.NewTuple(
   612  				expression.NewLiteral(true, types.Boolean),
   613  			),
   614  			result: true,
   615  		},
   616  		{
   617  			name: "left has an is true",
   618  			left: expression.NewIsTrue(
   619  				expression.NewLiteral(true, types.Boolean),
   620  			),
   621  			right: expression.NewTuple(
   622  				expression.NewLiteral(true, types.Boolean),
   623  			),
   624  			result: true,
   625  		},
   626  		{
   627  			name: "left has an arithmetic",
   628  			left: expression.NewPlus(
   629  				expression.NewLiteral(4, types.Int64),
   630  				expression.NewGetField(0, types.Int64, "foo", false),
   631  			),
   632  			right: expression.NewTuple(
   633  				expression.NewLiteral(6, types.Int64),
   634  			),
   635  			row:    sql.NewRow(int64(2), int64(0)),
   636  			result: true,
   637  		},
   638  	}
   639  
   640  	for _, tt := range testCases {
   641  		t.Run(tt.name, func(t *testing.T) {
   642  			ctx := sql.NewEmptyContext()
   643  			require := require.New(t)
   644  			expr, err := expression.NewHashInTuple(ctx, tt.left, tt.right)
   645  			if tt.staticErr != nil {
   646  				require.Error(err)
   647  				require.True(tt.staticErr.Is(err))
   648  			} else {
   649  				require.NoError(err)
   650  				result, err := expr.Eval(ctx, tt.row)
   651  				if tt.evalErr != nil {
   652  					require.Error(err)
   653  					require.True(tt.evalErr.Is(err))
   654  				} else {
   655  					require.NoError(err)
   656  					require.Equal(tt.result, result)
   657  				}
   658  			}
   659  		})
   660  	}
   661  }