github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/regexp_replace_test.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package function 16 17 import ( 18 "testing" 19 20 "github.com/stretchr/testify/require" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/expression" 24 "github.com/dolthub/go-mysql-server/sql/types" 25 ) 26 27 func TestRegexpReplaceInvalidArgNumber(t *testing.T) { 28 _, err := NewRegexpReplace() 29 require.Error(t, err) 30 31 _, err = NewRegexpReplace( 32 expression.NewGetField(0, types.LongText, "str", true), 33 ) 34 require.Error(t, err) 35 36 _, err = NewRegexpReplace( 37 expression.NewGetField(0, types.LongText, "str", true), 38 expression.NewGetField(1, types.LongText, "pattern", true), 39 ) 40 require.Error(t, err) 41 42 _, err = NewRegexpReplace( 43 expression.NewGetField(0, types.LongText, "str", true), 44 expression.NewGetField(1, types.LongText, "pattern", true), 45 expression.NewGetField(2, types.LongText, "replaceStr", true), 46 expression.NewGetField(3, types.LongText, "position", true), 47 expression.NewGetField(4, types.LongText, "occurrence", true), 48 expression.NewGetField(5, types.LongText, "flags", true), 49 expression.NewGetField(6, types.LongText, "???", true), 50 ) 51 require.Error(t, err) 52 } 53 54 func TestRegexpReplace(t *testing.T) { 55 f, err := NewRegexpReplace( 56 expression.NewGetField(0, types.LongText, "str", true), 57 expression.NewGetField(1, types.LongText, "pattern", true), 58 expression.NewGetField(2, types.LongText, "replaceStr", true), 59 ) 60 require.NoError(t, err) 61 62 testCases := []struct { 63 name string 64 row sql.Row 65 expected interface{} 66 err bool 67 }{ 68 { 69 "nil str", 70 sql.NewRow(nil, `[a-z]`, "X"), 71 nil, 72 false, 73 }, 74 { 75 "nil pattern", 76 sql.NewRow("abc def ghi", nil, "X"), 77 nil, 78 false, 79 }, 80 { 81 "nil replaceStr", 82 sql.NewRow("abc def ghi", `[a-z]`, nil), 83 nil, 84 false, 85 }, 86 { 87 "empty str", 88 sql.NewRow("", `[a-z]`, "a"), 89 "", 90 false, 91 }, 92 { 93 "empty pattern", 94 sql.NewRow("abc def ghi", ``, nil), 95 nil, 96 true, 97 }, 98 { 99 "empty replaceStr", 100 sql.NewRow("abc def ghi", `[a-z]`, ""), 101 " ", 102 false, 103 }, 104 { 105 "valid case", 106 sql.NewRow("abc def ghi", `[a-z]`, "X"), 107 "XXX XXX XXX", 108 false, 109 }, 110 } 111 112 for _, tt := range testCases { 113 t.Run(tt.name, func(t *testing.T) { 114 require := require.New(t) 115 ctx := sql.NewEmptyContext() 116 117 val, err := f.Eval(ctx, tt.row) 118 if tt.err { 119 require.Error(err) 120 } else { 121 require.NoError(err) 122 require.Equal(tt.expected, val) 123 } 124 }) 125 } 126 } 127 128 func TestRegexpReplaceWithPosition(t *testing.T) { 129 f, err := NewRegexpReplace( 130 expression.NewGetField(0, types.LongText, "str", true), 131 expression.NewGetField(1, types.LongText, "pattern", true), 132 expression.NewGetField(2, types.LongText, "replaceStr", true), 133 expression.NewGetField(3, types.LongText, "position", true), 134 ) 135 require.NoError(t, err) 136 137 testCases := []struct { 138 name string 139 row sql.Row 140 expected interface{} 141 err bool 142 }{ 143 { 144 "nil position", 145 sql.NewRow("abc def ghi", `[a-z]`, "X", nil), 146 nil, 147 false, 148 }, 149 { 150 "negative position", 151 sql.NewRow("abc def ghi", `[a-z]`, "X", -1), 152 nil, 153 true, 154 }, 155 { 156 "zero position", 157 sql.NewRow("abc def ghi", `[a-z]`, "X", 0), 158 nil, 159 true, 160 }, 161 { 162 "too large position", 163 sql.NewRow("abc def ghi", `[a-z]`, "X", 1000), 164 nil, 165 true, 166 }, 167 { 168 "string type position", 169 sql.NewRow("abc def ghi", `[a-z]`, "X", "1"), 170 "XXX XXX XXX", 171 false, 172 }, 173 { 174 "valid case", 175 sql.NewRow("abc def ghi", `[a-z]`, "X", 1), 176 "XXX XXX XXX", 177 false, 178 }, 179 { 180 "valid case", 181 sql.NewRow("abc def ghi", `[a-z]`, "X", 2), 182 "aXX XXX XXX", 183 false, 184 }, 185 { 186 "valid case", 187 sql.NewRow("abc def ghi", `[a-z]`, "X", 5), 188 "abc XXX XXX", 189 false, 190 }, 191 } 192 193 for _, tt := range testCases { 194 t.Run(tt.name, func(t *testing.T) { 195 require := require.New(t) 196 ctx := sql.NewEmptyContext() 197 198 val, err := f.Eval(ctx, tt.row) 199 if tt.err { 200 require.Error(err) 201 } else { 202 require.NoError(err) 203 require.Equal(tt.expected, val) 204 } 205 }) 206 } 207 } 208 209 func TestRegexpReplaceWithOccurrence(t *testing.T) { 210 f, err := NewRegexpReplace( 211 expression.NewGetField(0, types.LongText, "str", true), 212 expression.NewGetField(1, types.LongText, "pattern", true), 213 expression.NewGetField(2, types.LongText, "replaceStr", true), 214 expression.NewGetField(3, types.LongText, "position", true), 215 expression.NewGetField(4, types.LongText, "occurrence", true), 216 ) 217 require.NoError(t, err) 218 219 testCases := []struct { 220 name string 221 row sql.Row 222 expected interface{} 223 err bool 224 }{ 225 { 226 "nil occurrence", 227 sql.NewRow("abc def ghi", `[a-z]`, "X", 1, nil), 228 nil, 229 false, 230 }, 231 { 232 "string type occurrence", 233 sql.NewRow("abc def ghi", `[a-z]`, "X", 1, "0"), 234 "XXX XXX XXX", 235 false, 236 }, 237 { 238 "negative occurrence", 239 sql.NewRow("abc def ghi", `[a-z]`, "X", 1, -1), 240 "Xbc def ghi", 241 false, 242 }, 243 { 244 "zero occurrence", 245 sql.NewRow("abc def ghi", `[a-z]`, "X", 1, 0), 246 "XXX XXX XXX", 247 false, 248 }, 249 { 250 "one occurrence", 251 sql.NewRow("abc def ghi", `[a-z]`, "X", 1, 1), 252 "Xbc def ghi", 253 false, 254 }, 255 { 256 "positive occurrence", 257 sql.NewRow("abc def ghi", `[a-z]`, "X", 1, 4), 258 "abc Xef ghi", 259 false, 260 }, 261 { 262 "too large occurrence", 263 sql.NewRow("abc def ghi", `[a-z]`, "X", 1, 1000), 264 "abc def ghi", 265 false, 266 }, 267 { 268 "position and occurrence", 269 sql.NewRow("abc def ghi", `[a-z]`, "X", 5, 4), 270 "abc def Xhi", 271 false, 272 }, 273 } 274 275 for _, tt := range testCases { 276 t.Run(tt.name, func(t *testing.T) { 277 require := require.New(t) 278 ctx := sql.NewEmptyContext() 279 280 val, err := f.Eval(ctx, tt.row) 281 if tt.err { 282 require.Error(err) 283 } else { 284 require.NoError(err) 285 require.Equal(tt.expected, val) 286 } 287 }) 288 } 289 } 290 291 func TestRegexpReplaceWithFlags(t *testing.T) { 292 f, err := NewRegexpReplace( 293 expression.NewGetField(0, types.LongText, "str", true), 294 expression.NewGetField(1, types.LongText, "pattern", true), 295 expression.NewGetField(2, types.LongText, "replaceStr", true), 296 expression.NewGetField(3, types.LongText, "position", true), 297 expression.NewGetField(4, types.LongText, "occurrence", true), 298 expression.NewGetField(5, types.LongText, "flags", true), 299 ) 300 require.NoError(t, err) 301 302 testCases := []struct { 303 name string 304 row sql.Row 305 expected interface{} 306 err bool 307 }{ 308 { 309 "nil flags", 310 sql.NewRow("abc def ghi", `[a-z]`, "X", 1, 0, nil), 311 nil, 312 false, 313 }, 314 { 315 "bad flags", 316 sql.NewRow("abc def ghi", `[a-z]`, "X", 1, 0, "a"), 317 nil, 318 true, 319 }, 320 { 321 "case-sensitive flags", 322 sql.NewRow("abc DEF ghi", `[a-z]`, "X", 1, 0, "c"), 323 "XXX DEF XXX", 324 false, 325 }, 326 { 327 "case-insensitive flags", 328 sql.NewRow("abc DEF ghi", `[a-z]`, "X", 1, 0, "i"), 329 "XXX XXX XXX", 330 false, 331 }, 332 { 333 "multiline flags", 334 sql.NewRow("abc\r\ndef\r\nghi", `^[a-z].*$`, "X", 1, 0, "m"), 335 "X\r\nX\r\nX", 336 false, 337 }, 338 { 339 "insensitive and multiline flags", 340 sql.NewRow("abc\r\nDEF\r\nghi", `^[a-z].*$`, "X", 1, 0, "im"), 341 "X\r\nX\r\nX", 342 false, 343 }, 344 { 345 "sensitive and multiline flags", 346 sql.NewRow("abc\r\nDEF\r\nghi", `^[a-z].*$`, "X", 1, 0, "cm"), 347 "X\r\nDEF\r\nX", 348 false, 349 }, 350 { 351 "all flags", 352 sql.NewRow("abc\r\nDEF\r\nghi", `^[a-z].*$`, "X", 1, 0, "icm"), 353 "X\r\nDEF\r\nX", 354 false, 355 }, 356 { 357 "repeated flags", 358 sql.NewRow("abc DEF ghi", `[a-z]`, "X", 1, 0, "iiiiiicccc"), 359 "XXX DEF XXX", 360 false, 361 }, 362 } 363 364 for _, tt := range testCases { 365 t.Run(tt.name, func(t *testing.T) { 366 require := require.New(t) 367 ctx := sql.NewEmptyContext() 368 369 val, err := f.Eval(ctx, tt.row) 370 if tt.err { 371 require.Error(err) 372 } else { 373 require.NoError(err) 374 require.Equal(tt.expected, val) 375 } 376 }) 377 } 378 }