github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/comparison_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expression_test 16 17 import ( 18 "testing" 19 20 "github.com/stretchr/testify/require" 21 22 "github.com/dolthub/go-mysql-server/internal/regex" 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/expression" 25 "github.com/dolthub/go-mysql-server/sql/types" 26 ) 27 28 const ( 29 testEqual int = iota 30 testLess 31 testGreater 32 testRegexp 33 testNotRegexp 34 testNil 35 ) 36 37 var comparisonCases = map[sql.Type]map[int][][]interface{}{ 38 types.LongText: { 39 testEqual: { 40 {"foo", "foo"}, 41 {"", ""}, 42 }, 43 testLess: { 44 {"a", "b"}, 45 {"", "1"}, 46 }, 47 testGreater: { 48 {"b", "a"}, 49 {"1", ""}, 50 }, 51 testNil: { 52 {nil, "a"}, 53 {"a", nil}, 54 {nil, nil}, 55 }, 56 }, 57 types.Int32: { 58 testEqual: { 59 {int32(1), int32(1)}, 60 {int32(0), int32(0)}, 61 }, 62 testLess: { 63 {int32(-1), int32(0)}, 64 {int32(1), int32(2)}, 65 }, 66 testGreater: { 67 {int32(2), int32(1)}, 68 {int32(0), int32(-1)}, 69 }, 70 testNil: { 71 {nil, int32(1)}, 72 {int32(1), nil}, 73 {nil, nil}, 74 }, 75 }, 76 } 77 78 var likeComparisonCases = map[sql.Type]map[int][][]interface{}{ 79 types.LongText: { 80 testRegexp: { 81 {"foobar", ".*bar"}, 82 {"foobarfoo", ".*bar.*"}, 83 {"bar", "bar"}, 84 {"barfoo", "bar.*"}, 85 }, 86 testNotRegexp: { 87 {"foobara", ".*bar$"}, 88 {"foofoo", ".*bar.*"}, 89 {"bara", "bar$"}, 90 {"abarfoo", "^bar.*"}, 91 }, 92 testNil: { 93 {"foobar", nil}, 94 {nil, ".*bar"}, 95 {nil, nil}, 96 }, 97 }, 98 types.Int32: { 99 testRegexp: { 100 {int32(1), int32(1)}, 101 {int32(0), int32(0)}, 102 }, 103 testNotRegexp: { 104 {int32(-1), int32(0)}, 105 {int32(1), int32(2)}, 106 }, 107 }, 108 } 109 110 func TestEquals(t *testing.T) { 111 require := require.New(t) 112 for resultType, cmpCase := range comparisonCases { 113 get0 := expression.NewGetField(0, resultType, "col1", true) 114 require.NotNil(get0) 115 get1 := expression.NewGetField(1, resultType, "col2", true) 116 require.NotNil(get1) 117 eq := expression.NewEquals(get0, get1) 118 require.NotNil(eq) 119 require.Equal(types.Boolean, eq.Type()) 120 for cmpResult, cases := range cmpCase { 121 for _, pair := range cases { 122 row := sql.NewRow(pair[0], pair[1]) 123 require.NotNil(row) 124 cmp := eval(t, eq, row) 125 if cmpResult == testEqual { 126 require.Equal(true, cmp) 127 } else if cmpResult == testNil { 128 require.Nil(cmp) 129 } else { 130 require.Equal(false, cmp) 131 } 132 } 133 } 134 } 135 } 136 137 func TestNullSafeEquals(t *testing.T) { 138 require := require.New(t) 139 for resultType, cmpCase := range comparisonCases { 140 get0 := expression.NewGetField(0, resultType, "col1", true) 141 require.NotNil(get0) 142 get1 := expression.NewGetField(1, resultType, "col2", true) 143 require.NotNil(get1) 144 seq := expression.NewNullSafeEquals(get0, get1) 145 require.NotNil(seq) 146 require.Equal(types.Boolean, seq.Type()) 147 for cmpResult, cases := range cmpCase { 148 for _, pair := range cases { 149 row := sql.NewRow(pair[0], pair[1]) 150 require.NotNil(row) 151 cmp := eval(t, seq, row) 152 if cmpResult == testEqual { 153 require.Equal(true, cmp) 154 } else if cmpResult == testNil { 155 if pair[0] == nil && pair[1] == nil { 156 require.Equal(true, cmp) 157 } else { 158 require.Equal(false, cmp) 159 } 160 } else { 161 require.Equal(false, cmp) 162 } 163 } 164 } 165 } 166 } 167 168 func TestLessThan(t *testing.T) { 169 require := require.New(t) 170 for resultType, cmpCase := range comparisonCases { 171 get0 := expression.NewGetField(0, resultType, "col1", true) 172 require.NotNil(get0) 173 get1 := expression.NewGetField(1, resultType, "col2", true) 174 require.NotNil(get1) 175 eq := expression.NewLessThan(get0, get1) 176 require.NotNil(eq) 177 require.Equal(types.Boolean, eq.Type()) 178 for cmpResult, cases := range cmpCase { 179 for _, pair := range cases { 180 row := sql.NewRow(pair[0], pair[1]) 181 require.NotNil(row) 182 cmp := eval(t, eq, row) 183 if cmpResult == testLess { 184 require.Equal(true, cmp, "%v < %v", pair[0], pair[1]) 185 } else if cmpResult == testNil { 186 require.Nil(cmp) 187 } else { 188 require.Equal(false, cmp) 189 } 190 } 191 } 192 } 193 } 194 195 func TestGreaterThan(t *testing.T) { 196 require := require.New(t) 197 for resultType, cmpCase := range comparisonCases { 198 get0 := expression.NewGetField(0, resultType, "col1", true) 199 require.NotNil(get0) 200 get1 := expression.NewGetField(1, resultType, "col2", true) 201 require.NotNil(get1) 202 eq := expression.NewGreaterThan(get0, get1) 203 require.NotNil(eq) 204 require.Equal(types.Boolean, eq.Type()) 205 for cmpResult, cases := range cmpCase { 206 for _, pair := range cases { 207 row := sql.NewRow(pair[0], pair[1]) 208 require.NotNil(row) 209 cmp := eval(t, eq, row) 210 if cmpResult == testGreater { 211 require.Equal(true, cmp) 212 } else if cmpResult == testNil { 213 require.Nil(cmp) 214 } else { 215 require.Equal(false, cmp) 216 } 217 } 218 } 219 } 220 } 221 222 func TestRegexp(t *testing.T) { 223 for _, engine := range regex.Engines() { 224 regex.SetDefault(engine) 225 t.Run(engine, testRegexpCases) 226 } 227 } 228 229 func testRegexpCases(t *testing.T) { 230 t.Helper() 231 require := require.New(t) 232 233 for resultType, cmpCase := range likeComparisonCases { 234 get0 := expression.NewGetField(0, resultType, "col1", true) 235 require.NotNil(get0) 236 get1 := expression.NewGetField(1, resultType, "col2", true) 237 require.NotNil(get1) 238 for cmpResult, cases := range cmpCase { 239 for _, pair := range cases { 240 eq := expression.NewRegexp(get0, get1) 241 require.NotNil(eq) 242 require.Equal(types.Boolean, eq.Type()) 243 244 row := sql.NewRow(pair[0], pair[1]) 245 require.NotNil(row) 246 cmp := eval(t, eq, row) 247 if cmpResult == testRegexp { 248 require.Equal(true, cmp) 249 } else if cmpResult == testNil { 250 require.Nil(cmp) 251 } else { 252 require.Equal(false, cmp) 253 } 254 } 255 } 256 } 257 } 258 259 func TestInvalidRegexp(t *testing.T) { 260 t.Helper() 261 require := require.New(t) 262 263 col1 := expression.NewGetField(0, types.LongText, "col1", true) 264 invalid := expression.NewLiteral("*col1", types.LongText) 265 r := expression.NewRegexp(col1, invalid) 266 row := sql.NewRow("col1") 267 268 _, err := r.Eval(sql.NewEmptyContext(), row) 269 require.Error(err) 270 } 271 272 func eval(t *testing.T, e sql.Expression, row sql.Row) interface{} { 273 t.Helper() 274 v, err := e.Eval(sql.NewEmptyContext(), row) 275 require.NoError(t, err) 276 return v 277 }