github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/comparison_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression_test
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/dolthub/go-mysql-server/internal/regex"
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/types"
    26  )
    27  
    28  const (
    29  	testEqual int = iota
    30  	testLess
    31  	testGreater
    32  	testRegexp
    33  	testNotRegexp
    34  	testNil
    35  )
    36  
    37  var comparisonCases = map[sql.Type]map[int][][]interface{}{
    38  	types.LongText: {
    39  		testEqual: {
    40  			{"foo", "foo"},
    41  			{"", ""},
    42  		},
    43  		testLess: {
    44  			{"a", "b"},
    45  			{"", "1"},
    46  		},
    47  		testGreater: {
    48  			{"b", "a"},
    49  			{"1", ""},
    50  		},
    51  		testNil: {
    52  			{nil, "a"},
    53  			{"a", nil},
    54  			{nil, nil},
    55  		},
    56  	},
    57  	types.Int32: {
    58  		testEqual: {
    59  			{int32(1), int32(1)},
    60  			{int32(0), int32(0)},
    61  		},
    62  		testLess: {
    63  			{int32(-1), int32(0)},
    64  			{int32(1), int32(2)},
    65  		},
    66  		testGreater: {
    67  			{int32(2), int32(1)},
    68  			{int32(0), int32(-1)},
    69  		},
    70  		testNil: {
    71  			{nil, int32(1)},
    72  			{int32(1), nil},
    73  			{nil, nil},
    74  		},
    75  	},
    76  }
    77  
    78  var likeComparisonCases = map[sql.Type]map[int][][]interface{}{
    79  	types.LongText: {
    80  		testRegexp: {
    81  			{"foobar", ".*bar"},
    82  			{"foobarfoo", ".*bar.*"},
    83  			{"bar", "bar"},
    84  			{"barfoo", "bar.*"},
    85  		},
    86  		testNotRegexp: {
    87  			{"foobara", ".*bar$"},
    88  			{"foofoo", ".*bar.*"},
    89  			{"bara", "bar$"},
    90  			{"abarfoo", "^bar.*"},
    91  		},
    92  		testNil: {
    93  			{"foobar", nil},
    94  			{nil, ".*bar"},
    95  			{nil, nil},
    96  		},
    97  	},
    98  	types.Int32: {
    99  		testRegexp: {
   100  			{int32(1), int32(1)},
   101  			{int32(0), int32(0)},
   102  		},
   103  		testNotRegexp: {
   104  			{int32(-1), int32(0)},
   105  			{int32(1), int32(2)},
   106  		},
   107  	},
   108  }
   109  
   110  func TestEquals(t *testing.T) {
   111  	require := require.New(t)
   112  	for resultType, cmpCase := range comparisonCases {
   113  		get0 := expression.NewGetField(0, resultType, "col1", true)
   114  		require.NotNil(get0)
   115  		get1 := expression.NewGetField(1, resultType, "col2", true)
   116  		require.NotNil(get1)
   117  		eq := expression.NewEquals(get0, get1)
   118  		require.NotNil(eq)
   119  		require.Equal(types.Boolean, eq.Type())
   120  		for cmpResult, cases := range cmpCase {
   121  			for _, pair := range cases {
   122  				row := sql.NewRow(pair[0], pair[1])
   123  				require.NotNil(row)
   124  				cmp := eval(t, eq, row)
   125  				if cmpResult == testEqual {
   126  					require.Equal(true, cmp)
   127  				} else if cmpResult == testNil {
   128  					require.Nil(cmp)
   129  				} else {
   130  					require.Equal(false, cmp)
   131  				}
   132  			}
   133  		}
   134  	}
   135  }
   136  
   137  func TestNullSafeEquals(t *testing.T) {
   138  	require := require.New(t)
   139  	for resultType, cmpCase := range comparisonCases {
   140  		get0 := expression.NewGetField(0, resultType, "col1", true)
   141  		require.NotNil(get0)
   142  		get1 := expression.NewGetField(1, resultType, "col2", true)
   143  		require.NotNil(get1)
   144  		seq := expression.NewNullSafeEquals(get0, get1)
   145  		require.NotNil(seq)
   146  		require.Equal(types.Boolean, seq.Type())
   147  		for cmpResult, cases := range cmpCase {
   148  			for _, pair := range cases {
   149  				row := sql.NewRow(pair[0], pair[1])
   150  				require.NotNil(row)
   151  				cmp := eval(t, seq, row)
   152  				if cmpResult == testEqual {
   153  					require.Equal(true, cmp)
   154  				} else if cmpResult == testNil {
   155  					if pair[0] == nil && pair[1] == nil {
   156  						require.Equal(true, cmp)
   157  					} else {
   158  						require.Equal(false, cmp)
   159  					}
   160  				} else {
   161  					require.Equal(false, cmp)
   162  				}
   163  			}
   164  		}
   165  	}
   166  }
   167  
   168  func TestLessThan(t *testing.T) {
   169  	require := require.New(t)
   170  	for resultType, cmpCase := range comparisonCases {
   171  		get0 := expression.NewGetField(0, resultType, "col1", true)
   172  		require.NotNil(get0)
   173  		get1 := expression.NewGetField(1, resultType, "col2", true)
   174  		require.NotNil(get1)
   175  		eq := expression.NewLessThan(get0, get1)
   176  		require.NotNil(eq)
   177  		require.Equal(types.Boolean, eq.Type())
   178  		for cmpResult, cases := range cmpCase {
   179  			for _, pair := range cases {
   180  				row := sql.NewRow(pair[0], pair[1])
   181  				require.NotNil(row)
   182  				cmp := eval(t, eq, row)
   183  				if cmpResult == testLess {
   184  					require.Equal(true, cmp, "%v < %v", pair[0], pair[1])
   185  				} else if cmpResult == testNil {
   186  					require.Nil(cmp)
   187  				} else {
   188  					require.Equal(false, cmp)
   189  				}
   190  			}
   191  		}
   192  	}
   193  }
   194  
   195  func TestGreaterThan(t *testing.T) {
   196  	require := require.New(t)
   197  	for resultType, cmpCase := range comparisonCases {
   198  		get0 := expression.NewGetField(0, resultType, "col1", true)
   199  		require.NotNil(get0)
   200  		get1 := expression.NewGetField(1, resultType, "col2", true)
   201  		require.NotNil(get1)
   202  		eq := expression.NewGreaterThan(get0, get1)
   203  		require.NotNil(eq)
   204  		require.Equal(types.Boolean, eq.Type())
   205  		for cmpResult, cases := range cmpCase {
   206  			for _, pair := range cases {
   207  				row := sql.NewRow(pair[0], pair[1])
   208  				require.NotNil(row)
   209  				cmp := eval(t, eq, row)
   210  				if cmpResult == testGreater {
   211  					require.Equal(true, cmp)
   212  				} else if cmpResult == testNil {
   213  					require.Nil(cmp)
   214  				} else {
   215  					require.Equal(false, cmp)
   216  				}
   217  			}
   218  		}
   219  	}
   220  }
   221  
   222  func TestRegexp(t *testing.T) {
   223  	for _, engine := range regex.Engines() {
   224  		regex.SetDefault(engine)
   225  		t.Run(engine, testRegexpCases)
   226  	}
   227  }
   228  
   229  func testRegexpCases(t *testing.T) {
   230  	t.Helper()
   231  	require := require.New(t)
   232  
   233  	for resultType, cmpCase := range likeComparisonCases {
   234  		get0 := expression.NewGetField(0, resultType, "col1", true)
   235  		require.NotNil(get0)
   236  		get1 := expression.NewGetField(1, resultType, "col2", true)
   237  		require.NotNil(get1)
   238  		for cmpResult, cases := range cmpCase {
   239  			for _, pair := range cases {
   240  				eq := expression.NewRegexp(get0, get1)
   241  				require.NotNil(eq)
   242  				require.Equal(types.Boolean, eq.Type())
   243  
   244  				row := sql.NewRow(pair[0], pair[1])
   245  				require.NotNil(row)
   246  				cmp := eval(t, eq, row)
   247  				if cmpResult == testRegexp {
   248  					require.Equal(true, cmp)
   249  				} else if cmpResult == testNil {
   250  					require.Nil(cmp)
   251  				} else {
   252  					require.Equal(false, cmp)
   253  				}
   254  			}
   255  		}
   256  	}
   257  }
   258  
   259  func TestInvalidRegexp(t *testing.T) {
   260  	t.Helper()
   261  	require := require.New(t)
   262  
   263  	col1 := expression.NewGetField(0, types.LongText, "col1", true)
   264  	invalid := expression.NewLiteral("*col1", types.LongText)
   265  	r := expression.NewRegexp(col1, invalid)
   266  	row := sql.NewRow("col1")
   267  
   268  	_, err := r.Eval(sql.NewEmptyContext(), row)
   269  	require.Error(err)
   270  }
   271  
   272  func eval(t *testing.T, e sql.Expression, row sql.Row) interface{} {
   273  	t.Helper()
   274  	v, err := e.Eval(sql.NewEmptyContext(), row)
   275  	require.NoError(t, err)
   276  	return v
   277  }