github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/substring_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/expression"
    24  	"github.com/dolthub/go-mysql-server/sql/types"
    25  )
    26  
    27  func TestSubstring(t *testing.T) {
    28  	f, err := NewSubstring(
    29  		expression.NewGetField(0, types.LongText, "str", true),
    30  		expression.NewGetField(1, types.Int32, "start", false),
    31  		expression.NewGetField(2, types.Int64, "len", false),
    32  	)
    33  	require.NoError(t, err)
    34  
    35  	testCases := []struct {
    36  		name     string
    37  		row      sql.Row
    38  		expected interface{}
    39  		err      bool
    40  	}{
    41  		{"null string", sql.NewRow(nil, 1, 1), nil, false},
    42  		{"null start", sql.NewRow("foo", nil, 1), nil, false},
    43  		{"null len", sql.NewRow("foo", 1, nil), nil, false},
    44  		{"negative start", sql.NewRow("foo", -1, 10), "o", false},
    45  		{"negative length", sql.NewRow("foo", 1, -1), "", false},
    46  		{"length 0", sql.NewRow("foo", 1, 0), "", false},
    47  		{"start bigger than string", sql.NewRow("foo", 50, 10), "", false},
    48  		{"negative start bigger than string", sql.NewRow("foo", -4, 10), "", false},
    49  		{"length overflows", sql.NewRow("foo", 2, 10), "oo", false},
    50  		{"length overflows by one", sql.NewRow("foo", 2, 2), "oo", false},
    51  		{"substring contained", sql.NewRow("foo", 1, 2), "fo", false},
    52  		{"negative start until str beginning", sql.NewRow("foo", -3, 2), "fo", false},
    53  	}
    54  
    55  	for _, tt := range testCases {
    56  		t.Run(tt.name, func(t *testing.T) {
    57  			require := require.New(t)
    58  			ctx := sql.NewEmptyContext()
    59  
    60  			v, err := f.Eval(ctx, tt.row)
    61  			if tt.err {
    62  				require.Error(err)
    63  			} else {
    64  				require.NoError(err)
    65  				require.Equal(tt.expected, v)
    66  			}
    67  		})
    68  	}
    69  }
    70  
    71  func TestSubstringIndex(t *testing.T) {
    72  	f := NewSubstringIndex(
    73  		expression.NewGetField(0, types.LongText, "str", true),
    74  		expression.NewGetField(1, types.LongText, "delim", true),
    75  		expression.NewGetField(2, types.Int64, "count", false),
    76  	)
    77  	testCases := []struct {
    78  		name     string
    79  		row      sql.Row
    80  		expected interface{}
    81  		err      bool
    82  	}{
    83  		{"null string", sql.NewRow(nil, ".", 1), nil, false},
    84  		{"null delim", sql.NewRow("foo", nil, 1), nil, false},
    85  		{"null count", sql.NewRow("foo", 1, nil), nil, false},
    86  		{"positive count", sql.NewRow("a.b.c.d.e.f", ".", 2), "a.b", false},
    87  		{"negative count", sql.NewRow("a.b.c.d.e.f", ".", -2), "e.f", false},
    88  		{"count 0", sql.NewRow("a.b.c", ".", 0), "", false},
    89  		{"long delim", sql.NewRow("a.b.c.d.e.f", "..", 5), "a.b.c.d.e.f", false},
    90  		{"count > len", sql.NewRow("a.b.c", ".", 10), "a.b.c", false},
    91  		{"-count > -len", sql.NewRow("a.b.c", ".", -10), "a.b.c", false},
    92  		{"remove suffix", sql.NewRow("source{d}", "{d}", 1), "source", false},
    93  		{"remove suffix with negtive count", sql.NewRow("source{d}", "{d}", -1), "", false},
    94  		{"wrong count type", sql.NewRow("", "", "foo"), "", true},
    95  	}
    96  
    97  	for _, tt := range testCases {
    98  		t.Run(tt.name, func(t *testing.T) {
    99  			require := require.New(t)
   100  			ctx := sql.NewEmptyContext()
   101  
   102  			v, err := f.Eval(ctx, tt.row)
   103  			if tt.err {
   104  				require.Error(err)
   105  			} else {
   106  				require.NoError(err)
   107  				require.Equal(tt.expected, v)
   108  			}
   109  		})
   110  	}
   111  }
   112  
   113  func TestInstr(t *testing.T) {
   114  	f := NewInstr(
   115  		expression.NewGetField(0, types.LongText, "str", true),
   116  		expression.NewGetField(1, types.LongText, "substr", false),
   117  	)
   118  
   119  	testCases := []struct {
   120  		name     string
   121  		row      sql.Row
   122  		expected interface{}
   123  		err      bool
   124  	}{
   125  		{"both null", sql.NewRow(nil, nil), nil, false},
   126  		{"null string", sql.NewRow(nil, "hello"), nil, false},
   127  		{"null substr", sql.NewRow("foo", nil), nil, false},
   128  		{"total match", sql.NewRow("foo", "foo"), 1, false},
   129  		{"midword match", sql.NewRow("foobar", "bar"), 4, false},
   130  		{"non match", sql.NewRow("foo", "bar"), 0, false},
   131  		{"substr bigger than string", sql.NewRow("foo", "foobar"), 0, false},
   132  		{"multiple matches", sql.NewRow("bobobo", "bo"), 1, false},
   133  		{"bad string", sql.NewRow(1, "hello"), 0, true},
   134  		{"bad substr", sql.NewRow("foo", 1), 0, true},
   135  	}
   136  
   137  	for _, tt := range testCases {
   138  		t.Run(tt.name, func(t *testing.T) {
   139  			require := require.New(t)
   140  			ctx := sql.NewEmptyContext()
   141  
   142  			v, err := f.Eval(ctx, tt.row)
   143  			if tt.err {
   144  				require.Error(err)
   145  			} else {
   146  				require.NoError(err)
   147  				var expected interface{}
   148  				if i, ok := tt.expected.(int); ok {
   149  					expected = int64(i)
   150  				}
   151  				require.Equal(expected, v)
   152  			}
   153  		})
   154  	}
   155  }
   156  
   157  func TestLeft(t *testing.T) {
   158  	f := NewLeft(
   159  		expression.NewGetField(0, types.LongText, "str", true),
   160  		expression.NewGetField(1, types.Int64, "len", false),
   161  	)
   162  
   163  	testCases := []struct {
   164  		name     string
   165  		row      sql.Row
   166  		expected interface{}
   167  		err      bool
   168  	}{
   169  		{"both null", sql.NewRow(nil, nil), nil, false},
   170  		{"null string", sql.NewRow(nil, 1), nil, false},
   171  		{"null len", sql.NewRow("foo", nil), nil, false},
   172  		{"len == string.len", sql.NewRow("foo", 3), "foo", false},
   173  		{"len > string.len", sql.NewRow("foo", 10), "foo", false},
   174  		{"len == 0", sql.NewRow("foo", 0), "", false},
   175  		{"len < 0", sql.NewRow("foo", -1), "", false},
   176  		{"len < string.len", sql.NewRow("foo", 2), "fo", false},
   177  		{"bad string type", sql.NewRow(1, 1), "", true},
   178  		{"bad len type", sql.NewRow("hello", "hello"), "", true},
   179  	}
   180  
   181  	for _, tt := range testCases {
   182  		t.Run(tt.name, func(t *testing.T) {
   183  			require := require.New(t)
   184  			ctx := sql.NewEmptyContext()
   185  
   186  			v, err := f.Eval(ctx, tt.row)
   187  			if tt.err {
   188  				require.Error(err)
   189  			} else {
   190  				require.NoError(err)
   191  				require.Equal(tt.expected, v)
   192  			}
   193  		})
   194  	}
   195  }
   196  
   197  func TestRight(t *testing.T) {
   198  	f := NewRight(
   199  		expression.NewGetField(0, types.LongText, "str", true),
   200  		expression.NewGetField(1, types.Int64, "len", false),
   201  	)
   202  
   203  	testCases := []struct {
   204  		name     string
   205  		row      sql.Row
   206  		expected interface{}
   207  		err      bool
   208  	}{
   209  		{"both null", sql.NewRow(nil, nil), nil, false},
   210  		{"null string", sql.NewRow(nil, 1), nil, false},
   211  		{"null len", sql.NewRow("foo", nil), nil, false},
   212  		{"len == string.len", sql.NewRow("foo", 3), "foo", false},
   213  		{"len > string.len", sql.NewRow("foo", 10), "foo", false},
   214  		{"len == 0", sql.NewRow("foo", 0), "", false},
   215  		{"len < 0", sql.NewRow("foo", -1), "", false},
   216  		{"len < string.len", sql.NewRow("foo", 2), "oo", false},
   217  		{"bad string type", sql.NewRow(1, 1), "", true},
   218  		{"bad len type", sql.NewRow("hello", "hello"), "", true},
   219  	}
   220  
   221  	for _, tt := range testCases {
   222  		t.Run(tt.name, func(t *testing.T) {
   223  			require := require.New(t)
   224  			ctx := sql.NewEmptyContext()
   225  
   226  			v, err := f.Eval(ctx, tt.row)
   227  			if tt.err {
   228  				require.Error(err)
   229  			} else {
   230  				require.NoError(err)
   231  				require.Equal(tt.expected, v)
   232  			}
   233  		})
   234  	}
   235  }