github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/regexp_replace_test.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/expression"
    24  	"github.com/dolthub/go-mysql-server/sql/types"
    25  )
    26  
    27  func TestRegexpReplaceInvalidArgNumber(t *testing.T) {
    28  	_, err := NewRegexpReplace()
    29  	require.Error(t, err)
    30  
    31  	_, err = NewRegexpReplace(
    32  		expression.NewGetField(0, types.LongText, "str", true),
    33  	)
    34  	require.Error(t, err)
    35  
    36  	_, err = NewRegexpReplace(
    37  		expression.NewGetField(0, types.LongText, "str", true),
    38  		expression.NewGetField(1, types.LongText, "pattern", true),
    39  	)
    40  	require.Error(t, err)
    41  
    42  	_, err = NewRegexpReplace(
    43  		expression.NewGetField(0, types.LongText, "str", true),
    44  		expression.NewGetField(1, types.LongText, "pattern", true),
    45  		expression.NewGetField(2, types.LongText, "replaceStr", true),
    46  		expression.NewGetField(3, types.LongText, "position", true),
    47  		expression.NewGetField(4, types.LongText, "occurrence", true),
    48  		expression.NewGetField(5, types.LongText, "flags", true),
    49  		expression.NewGetField(6, types.LongText, "???", true),
    50  	)
    51  	require.Error(t, err)
    52  }
    53  
    54  func TestRegexpReplace(t *testing.T) {
    55  	f, err := NewRegexpReplace(
    56  		expression.NewGetField(0, types.LongText, "str", true),
    57  		expression.NewGetField(1, types.LongText, "pattern", true),
    58  		expression.NewGetField(2, types.LongText, "replaceStr", true),
    59  	)
    60  	require.NoError(t, err)
    61  
    62  	testCases := []struct {
    63  		name     string
    64  		row      sql.Row
    65  		expected interface{}
    66  		err      bool
    67  	}{
    68  		{
    69  			"nil str",
    70  			sql.NewRow(nil, `[a-z]`, "X"),
    71  			nil,
    72  			false,
    73  		},
    74  		{
    75  			"nil pattern",
    76  			sql.NewRow("abc def ghi", nil, "X"),
    77  			nil,
    78  			false,
    79  		},
    80  		{
    81  			"nil replaceStr",
    82  			sql.NewRow("abc def ghi", `[a-z]`, nil),
    83  			nil,
    84  			false,
    85  		},
    86  		{
    87  			"empty str",
    88  			sql.NewRow("", `[a-z]`, "a"),
    89  			"",
    90  			false,
    91  		},
    92  		{
    93  			"empty pattern",
    94  			sql.NewRow("abc def ghi", ``, nil),
    95  			nil,
    96  			true,
    97  		},
    98  		{
    99  			"empty replaceStr",
   100  			sql.NewRow("abc def ghi", `[a-z]`, ""),
   101  			"  ",
   102  			false,
   103  		},
   104  		{
   105  			"valid case",
   106  			sql.NewRow("abc def ghi", `[a-z]`, "X"),
   107  			"XXX XXX XXX",
   108  			false,
   109  		},
   110  	}
   111  
   112  	for _, tt := range testCases {
   113  		t.Run(tt.name, func(t *testing.T) {
   114  			require := require.New(t)
   115  			ctx := sql.NewEmptyContext()
   116  
   117  			val, err := f.Eval(ctx, tt.row)
   118  			if tt.err {
   119  				require.Error(err)
   120  			} else {
   121  				require.NoError(err)
   122  				require.Equal(tt.expected, val)
   123  			}
   124  		})
   125  	}
   126  }
   127  
   128  func TestRegexpReplaceWithPosition(t *testing.T) {
   129  	f, err := NewRegexpReplace(
   130  		expression.NewGetField(0, types.LongText, "str", true),
   131  		expression.NewGetField(1, types.LongText, "pattern", true),
   132  		expression.NewGetField(2, types.LongText, "replaceStr", true),
   133  		expression.NewGetField(3, types.LongText, "position", true),
   134  	)
   135  	require.NoError(t, err)
   136  
   137  	testCases := []struct {
   138  		name     string
   139  		row      sql.Row
   140  		expected interface{}
   141  		err      bool
   142  	}{
   143  		{
   144  			"nil position",
   145  			sql.NewRow("abc def ghi", `[a-z]`, "X", nil),
   146  			nil,
   147  			false,
   148  		},
   149  		{
   150  			"negative position",
   151  			sql.NewRow("abc def ghi", `[a-z]`, "X", -1),
   152  			nil,
   153  			true,
   154  		},
   155  		{
   156  			"zero position",
   157  			sql.NewRow("abc def ghi", `[a-z]`, "X", 0),
   158  			nil,
   159  			true,
   160  		},
   161  		{
   162  			"too large position",
   163  			sql.NewRow("abc def ghi", `[a-z]`, "X", 1000),
   164  			nil,
   165  			true,
   166  		},
   167  		{
   168  			"string type position",
   169  			sql.NewRow("abc def ghi", `[a-z]`, "X", "1"),
   170  			"XXX XXX XXX",
   171  			false,
   172  		},
   173  		{
   174  			"valid case",
   175  			sql.NewRow("abc def ghi", `[a-z]`, "X", 1),
   176  			"XXX XXX XXX",
   177  			false,
   178  		},
   179  		{
   180  			"valid case",
   181  			sql.NewRow("abc def ghi", `[a-z]`, "X", 2),
   182  			"aXX XXX XXX",
   183  			false,
   184  		},
   185  		{
   186  			"valid case",
   187  			sql.NewRow("abc def ghi", `[a-z]`, "X", 5),
   188  			"abc XXX XXX",
   189  			false,
   190  		},
   191  	}
   192  
   193  	for _, tt := range testCases {
   194  		t.Run(tt.name, func(t *testing.T) {
   195  			require := require.New(t)
   196  			ctx := sql.NewEmptyContext()
   197  
   198  			val, err := f.Eval(ctx, tt.row)
   199  			if tt.err {
   200  				require.Error(err)
   201  			} else {
   202  				require.NoError(err)
   203  				require.Equal(tt.expected, val)
   204  			}
   205  		})
   206  	}
   207  }
   208  
   209  func TestRegexpReplaceWithOccurrence(t *testing.T) {
   210  	f, err := NewRegexpReplace(
   211  		expression.NewGetField(0, types.LongText, "str", true),
   212  		expression.NewGetField(1, types.LongText, "pattern", true),
   213  		expression.NewGetField(2, types.LongText, "replaceStr", true),
   214  		expression.NewGetField(3, types.LongText, "position", true),
   215  		expression.NewGetField(4, types.LongText, "occurrence", true),
   216  	)
   217  	require.NoError(t, err)
   218  
   219  	testCases := []struct {
   220  		name     string
   221  		row      sql.Row
   222  		expected interface{}
   223  		err      bool
   224  	}{
   225  		{
   226  			"nil occurrence",
   227  			sql.NewRow("abc def ghi", `[a-z]`, "X", 1, nil),
   228  			nil,
   229  			false,
   230  		},
   231  		{
   232  			"string type occurrence",
   233  			sql.NewRow("abc def ghi", `[a-z]`, "X", 1, "0"),
   234  			"XXX XXX XXX",
   235  			false,
   236  		},
   237  		{
   238  			"negative occurrence",
   239  			sql.NewRow("abc def ghi", `[a-z]`, "X", 1, -1),
   240  			"Xbc def ghi",
   241  			false,
   242  		},
   243  		{
   244  			"zero occurrence",
   245  			sql.NewRow("abc def ghi", `[a-z]`, "X", 1, 0),
   246  			"XXX XXX XXX",
   247  			false,
   248  		},
   249  		{
   250  			"one occurrence",
   251  			sql.NewRow("abc def ghi", `[a-z]`, "X", 1, 1),
   252  			"Xbc def ghi",
   253  			false,
   254  		},
   255  		{
   256  			"positive occurrence",
   257  			sql.NewRow("abc def ghi", `[a-z]`, "X", 1, 4),
   258  			"abc Xef ghi",
   259  			false,
   260  		},
   261  		{
   262  			"too large occurrence",
   263  			sql.NewRow("abc def ghi", `[a-z]`, "X", 1, 1000),
   264  			"abc def ghi",
   265  			false,
   266  		},
   267  		{
   268  			"position and occurrence",
   269  			sql.NewRow("abc def ghi", `[a-z]`, "X", 5, 4),
   270  			"abc def Xhi",
   271  			false,
   272  		},
   273  	}
   274  
   275  	for _, tt := range testCases {
   276  		t.Run(tt.name, func(t *testing.T) {
   277  			require := require.New(t)
   278  			ctx := sql.NewEmptyContext()
   279  
   280  			val, err := f.Eval(ctx, tt.row)
   281  			if tt.err {
   282  				require.Error(err)
   283  			} else {
   284  				require.NoError(err)
   285  				require.Equal(tt.expected, val)
   286  			}
   287  		})
   288  	}
   289  }
   290  
   291  func TestRegexpReplaceWithFlags(t *testing.T) {
   292  	f, err := NewRegexpReplace(
   293  		expression.NewGetField(0, types.LongText, "str", true),
   294  		expression.NewGetField(1, types.LongText, "pattern", true),
   295  		expression.NewGetField(2, types.LongText, "replaceStr", true),
   296  		expression.NewGetField(3, types.LongText, "position", true),
   297  		expression.NewGetField(4, types.LongText, "occurrence", true),
   298  		expression.NewGetField(5, types.LongText, "flags", true),
   299  	)
   300  	require.NoError(t, err)
   301  
   302  	testCases := []struct {
   303  		name     string
   304  		row      sql.Row
   305  		expected interface{}
   306  		err      bool
   307  	}{
   308  		{
   309  			"nil flags",
   310  			sql.NewRow("abc def ghi", `[a-z]`, "X", 1, 0, nil),
   311  			nil,
   312  			false,
   313  		},
   314  		{
   315  			"bad flags",
   316  			sql.NewRow("abc def ghi", `[a-z]`, "X", 1, 0, "a"),
   317  			nil,
   318  			true,
   319  		},
   320  		{
   321  			"case-sensitive flags",
   322  			sql.NewRow("abc DEF ghi", `[a-z]`, "X", 1, 0, "c"),
   323  			"XXX DEF XXX",
   324  			false,
   325  		},
   326  		{
   327  			"case-insensitive flags",
   328  			sql.NewRow("abc DEF ghi", `[a-z]`, "X", 1, 0, "i"),
   329  			"XXX XXX XXX",
   330  			false,
   331  		},
   332  		{
   333  			"multiline flags",
   334  			sql.NewRow("abc\r\ndef\r\nghi", `^[a-z].*$`, "X", 1, 0, "m"),
   335  			"X\r\nX\r\nX",
   336  			false,
   337  		},
   338  		{
   339  			"insensitive and multiline flags",
   340  			sql.NewRow("abc\r\nDEF\r\nghi", `^[a-z].*$`, "X", 1, 0, "im"),
   341  			"X\r\nX\r\nX",
   342  			false,
   343  		},
   344  		{
   345  			"sensitive and multiline flags",
   346  			sql.NewRow("abc\r\nDEF\r\nghi", `^[a-z].*$`, "X", 1, 0, "cm"),
   347  			"X\r\nDEF\r\nX",
   348  			false,
   349  		},
   350  		{
   351  			"all flags",
   352  			sql.NewRow("abc\r\nDEF\r\nghi", `^[a-z].*$`, "X", 1, 0, "icm"),
   353  			"X\r\nDEF\r\nX",
   354  			false,
   355  		},
   356  		{
   357  			"repeated flags",
   358  			sql.NewRow("abc DEF ghi", `[a-z]`, "X", 1, 0, "iiiiiicccc"),
   359  			"XXX DEF XXX",
   360  			false,
   361  		},
   362  	}
   363  
   364  	for _, tt := range testCases {
   365  		t.Run(tt.name, func(t *testing.T) {
   366  			require := require.New(t)
   367  			ctx := sql.NewEmptyContext()
   368  
   369  			val, err := f.Eval(ctx, tt.row)
   370  			if tt.err {
   371  				require.Error(err)
   372  			} else {
   373  				require.NoError(err)
   374  				require.Equal(tt.expected, val)
   375  			}
   376  		})
   377  	}
   378  }