github.com/dolthub/go-mysql-server@v0.18.0/sql/types/set_test.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package types
    16  
    17  import (
    18  	"fmt"
    19  	"reflect"
    20  	"strconv"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/stretchr/testify/assert"
    25  	"github.com/stretchr/testify/require"
    26  
    27  	"github.com/dolthub/go-mysql-server/sql"
    28  )
    29  
    30  func TestSetCompare(t *testing.T) {
    31  	tests := []struct {
    32  		vals        []string
    33  		collation   sql.CollationID
    34  		val1        interface{}
    35  		val2        interface{}
    36  		expectedCmp int
    37  	}{
    38  		{[]string{"one", "two"}, sql.Collation_Default, nil, 1, 1},
    39  		{[]string{"one", "two"}, sql.Collation_Default, "one", nil, -1},
    40  		{[]string{"one", "two"}, sql.Collation_Default, nil, nil, 0},
    41  		{[]string{"one", "two"}, sql.Collation_Default, 0, "one", -1},
    42  		{[]string{"one", "two"}, sql.Collation_Default, 1, "two", -1},
    43  		{[]string{"one", "two"}, sql.Collation_Default, 2, []byte("one"), 1},
    44  		{[]string{"one", "two"}, sql.Collation_Default, "one", "", 1},
    45  		{[]string{"one", "two"}, sql.Collation_Default, "one", 1, 0},
    46  		{[]string{"one", "two"}, sql.Collation_Default, "one", "two", -1},
    47  		{[]string{"two", "one"}, sql.Collation_binary, "two", "one", -1},
    48  		{[]string{"one", "two"}, sql.Collation_Default, 3, "one,two", 0},
    49  		{[]string{"one", "two"}, sql.Collation_Default, "two,one,two", "one,two", 0},
    50  		{[]string{"one", "two"}, sql.Collation_Default, "two", "", 1},
    51  		{[]string{"one", "two"}, sql.Collation_Default, "one,two", "two", 1},
    52  		{[]string{"a", "b", "c"}, sql.Collation_Default, "a,b", "b,c", -1},
    53  		{[]string{"a", "b", "c"}, sql.Collation_Default, "a,b,c", "c,c,b", 1},
    54  	}
    55  
    56  	for _, test := range tests {
    57  		t.Run(fmt.Sprintf("%v %v %v %v", test.vals, test.collation, test.val1, test.val2), func(t *testing.T) {
    58  			typ := MustCreateSetType(test.vals, test.collation)
    59  			cmp, err := typ.Compare(test.val1, test.val2)
    60  			require.NoError(t, err)
    61  			assert.Equal(t, test.expectedCmp, cmp)
    62  		})
    63  	}
    64  }
    65  
    66  func TestSetCompareErrors(t *testing.T) {
    67  	tests := []struct {
    68  		vals      []string
    69  		collation sql.CollationID
    70  		val1      interface{}
    71  		val2      interface{}
    72  	}{
    73  		{[]string{"one", "two"}, sql.Collation_Default, "three", "two"},
    74  		{[]string{"one", "two"}, sql.Collation_Default, time.Date(2019, 12, 12, 12, 12, 12, 0, time.UTC), []byte("one")},
    75  	}
    76  
    77  	for _, test := range tests {
    78  		t.Run(fmt.Sprintf("%v %v %v %v", test.vals, test.collation, test.val1, test.val2), func(t *testing.T) {
    79  			typ := MustCreateSetType(test.vals, test.collation)
    80  			_, err := typ.Compare(test.val1, test.val2)
    81  			require.Error(t, err)
    82  		})
    83  	}
    84  }
    85  
    86  func TestSetCreate(t *testing.T) {
    87  	tests := []struct {
    88  		vals         []string
    89  		collation    sql.CollationID
    90  		expectedVals map[string]uint64
    91  		expectedErr  bool
    92  	}{
    93  		{[]string{"one"}, sql.Collation_Default,
    94  			map[string]uint64{"one": 1}, false},
    95  		{[]string{" one ", "  two  "}, sql.Collation_Default,
    96  			map[string]uint64{" one": 1, "  two": 2}, false},
    97  		{[]string{"a", "b", "c"}, sql.Collation_Default,
    98  			map[string]uint64{"a": 1, "b": 2, "c": 4}, false},
    99  		{[]string{"one", "one "}, sql.Collation_binary, map[string]uint64{"one": 1, "one ": 2}, false},
   100  		{[]string{"one", "One"}, sql.Collation_binary, map[string]uint64{"one": 1, "One": 2}, false},
   101  
   102  		{[]string{}, sql.Collation_Default, nil, true},
   103  		{[]string{"one", "one"}, sql.Collation_Default, nil, true},
   104  		{[]string{"one", "one"}, sql.Collation_binary, nil, true},
   105  		{[]string{"one", "One"}, sql.Collation_utf8mb4_general_ci, nil, true},
   106  		{[]string{"one", "one "}, sql.Collation_Default, nil, true},
   107  		{[]string{"one", "two,"}, sql.Collation_Default, nil, true},
   108  	}
   109  
   110  	for _, test := range tests {
   111  		t.Run(fmt.Sprintf("%v %v", test.vals, test.collation), func(t *testing.T) {
   112  			typ, err := CreateSetType(test.vals, test.collation)
   113  			if test.expectedErr {
   114  				assert.Error(t, err)
   115  			} else {
   116  				require.NoError(t, err)
   117  				concreteType, ok := typ.(SetType)
   118  				require.True(t, ok)
   119  				assert.True(t, test.collation.Equals(typ.Collation()))
   120  				for val, bit := range test.expectedVals {
   121  					bitField, err := concreteType.convertStringToBitField(val)
   122  					if assert.NoError(t, err) {
   123  						assert.Equal(t, bit, bitField)
   124  					}
   125  					str, err := concreteType.convertBitFieldToString(bit)
   126  					if assert.NoError(t, err) {
   127  						assert.Equal(t, val, str)
   128  					}
   129  				}
   130  			}
   131  		})
   132  	}
   133  }
   134  
   135  func TestSetCreateTooLarge(t *testing.T) {
   136  	vals := make([]string, 65)
   137  	for i := range vals {
   138  		vals[i] = strconv.Itoa(i)
   139  	}
   140  	_, err := CreateSetType(vals, sql.Collation_Default)
   141  	require.Error(t, err)
   142  }
   143  
   144  func TestSetConvert(t *testing.T) {
   145  	tests := []struct {
   146  		vals        []string
   147  		collation   sql.CollationID
   148  		val         interface{}
   149  		expectedVal interface{}
   150  		expectedErr bool
   151  	}{
   152  		{[]string{"one", "two"}, sql.Collation_Default, nil, nil, false},
   153  		{[]string{"one", "two"}, sql.Collation_Default, "", "", false},
   154  		{[]string{"one", "two"}, sql.Collation_Default, int(0), "", false},
   155  		{[]string{"one", "two"}, sql.Collation_Default, int8(2), "two", false},
   156  		{[]string{"one", "two"}, sql.Collation_Default, int16(1), "one", false},
   157  		{[]string{"one", "two"}, sql.Collation_binary, int32(2), "two", false},
   158  		{[]string{"one", "two"}, sql.Collation_Default, int64(1), "one", false},
   159  		{[]string{"one", "two"}, sql.Collation_Default, uint(2), "two", false},
   160  		{[]string{"one", "two"}, sql.Collation_binary, uint8(1), "one", false},
   161  		{[]string{"one", "two"}, sql.Collation_Default, uint16(2), "two", false},
   162  		{[]string{"one", "two"}, sql.Collation_binary, uint32(3), "one,two", false},
   163  		{[]string{"one", "two"}, sql.Collation_Default, uint64(2), "two", false},
   164  		{[]string{"one", "two"}, sql.Collation_Default, "one", "one", false},
   165  		{[]string{"one", "two"}, sql.Collation_Default, []byte("two"), "two", false},
   166  		{[]string{"one", "two"}, sql.Collation_Default, "one,two", "one,two", false},
   167  		{[]string{"one", "two"}, sql.Collation_binary, "two,one", "one,two", false},
   168  		{[]string{"one", "two"}, sql.Collation_Default, "one,two,one", "one,two", false},
   169  		{[]string{"one", "two"}, sql.Collation_binary, "two,one,two", "one,two", false},
   170  		{[]string{"one", "two"}, sql.Collation_Default, "two,one,two", "one,two", false},
   171  		{[]string{"a", "b", "c"}, sql.Collation_Default, "b,c  ,a", "a,b,c", false},
   172  		{[]string{"one", "two"}, sql.Collation_utf8mb4_general_ci, "ONE", "one", false},
   173  		{[]string{"ONE", "two"}, sql.Collation_utf8mb4_general_ci, "one", "ONE", false},
   174  		{[]string{"", "one", "two"}, sql.Collation_Default, "", "", false},
   175  		{[]string{"", "one", "two"}, sql.Collation_Default, ",one,two", ",one,two", false},
   176  		{[]string{"", "one", "two"}, sql.Collation_Default, "one,,two", ",one,two", false},
   177  
   178  		{[]string{"one", "two"}, sql.Collation_Default, ",one,two", nil, true},
   179  		{[]string{"one", "two"}, sql.Collation_Default, 4, nil, true},
   180  		{[]string{"one", "two"}, sql.Collation_Default, "three", nil, true},
   181  		{[]string{"one", "two"}, sql.Collation_Default, "one,two,three", nil, true},
   182  		{[]string{"a", "b", "c"}, sql.Collation_binary, "b,c  ,a", nil, true},
   183  		{[]string{"one", "two"}, sql.Collation_binary, "ONE", nil, true},
   184  		{[]string{"one", "two"}, sql.Collation_Default, time.Date(2019, 12, 12, 12, 12, 12, 0, time.UTC), nil, true},
   185  	}
   186  
   187  	for _, test := range tests {
   188  		t.Run(fmt.Sprintf("%v | %v | %v", test.vals, test.collation, test.val), func(t *testing.T) {
   189  			typ := MustCreateSetType(test.vals, test.collation)
   190  			val, _, err := typ.Convert(test.val)
   191  			if test.expectedErr {
   192  				assert.Error(t, err)
   193  			} else {
   194  				require.NoError(t, err)
   195  				res, err := typ.Compare(test.expectedVal, val)
   196  				require.NoError(t, err)
   197  				assert.Equal(t, 0, res)
   198  				if val != nil {
   199  					assert.Equal(t, typ.ValueType(), reflect.TypeOf(val))
   200  				}
   201  			}
   202  		})
   203  	}
   204  }
   205  
   206  func TestSetMarshalMax(t *testing.T) {
   207  	vals := make([]string, 64)
   208  	for i := range vals {
   209  		vals[i] = strconv.Itoa(i)
   210  	}
   211  	typ, err := CreateSetType(vals, sql.Collation_Default)
   212  	require.NoError(t, err)
   213  
   214  	tests := []string{
   215  		"",
   216  		"1",
   217  		"1,2",
   218  		"0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63",
   219  	}
   220  
   221  	for _, test := range tests {
   222  		t.Run(fmt.Sprintf("%v", test), func(t *testing.T) {
   223  			bits, _, err := typ.Convert(test)
   224  			require.NoError(t, err)
   225  			res1, err := typ.BitsToString(bits.(uint64))
   226  			require.NoError(t, err)
   227  			require.Equal(t, test, res1)
   228  			bits2, _, err := typ.Convert(bits)
   229  			require.NoError(t, err)
   230  			res2, err := typ.BitsToString(bits2.(uint64))
   231  			require.NoError(t, err)
   232  			require.Equal(t, test, res2)
   233  		})
   234  	}
   235  }
   236  
   237  func TestSetString(t *testing.T) {
   238  	tests := []struct {
   239  		vals        []string
   240  		collation   sql.CollationID
   241  		expectedStr string
   242  	}{
   243  		{[]string{"one"}, sql.Collation_Default, "set('one')"},
   244  		{[]string{"مرحبا", "こんにちは"}, sql.Collation_Default, "set('مرحبا','こんにちは')"},
   245  		{[]string{" hi ", "  lo  "}, sql.Collation_Default, "set(' hi','  lo')"},
   246  		{[]string{" hi ", "  lo  "}, sql.Collation_binary, "set(' hi ','  lo  ') CHARACTER SET binary COLLATE binary"},
   247  		{[]string{"a"}, sql.Collation_Default.CharacterSet().BinaryCollation(),
   248  			fmt.Sprintf("set('a') COLLATE %v", sql.Collation_Default.CharacterSet().BinaryCollation())},
   249  	}
   250  
   251  	for _, test := range tests {
   252  		t.Run(fmt.Sprintf("%v %v", test.vals, test.collation), func(t *testing.T) {
   253  			typ := MustCreateSetType(test.vals, test.collation)
   254  			assert.Equal(t, test.expectedStr, typ.String())
   255  		})
   256  	}
   257  }
   258  
   259  func TestSetZero(t *testing.T) {
   260  	setType := MustCreateSetType([]string{"a", "b"}, sql.Collation_Default)
   261  	require.Equal(t, uint64(0), setType.Zero())
   262  }
   263  
   264  func TestSetConvertToString(t *testing.T) {
   265  	tests := []struct {
   266  		vals        []string
   267  		collation   sql.CollationID
   268  		bit         uint64
   269  		expectedStr string
   270  	}{
   271  		{[]string{"", "a", "b", "c"}, sql.Collation_Default, 15, "a,b,c"},
   272  		{[]string{"", "a", "b", "c"}, sql.Collation_Default, 14, "a,b,c"},
   273  	}
   274  
   275  	for _, test := range tests {
   276  		t.Run(fmt.Sprintf("%v %v", test.vals, test.collation), func(t *testing.T) {
   277  			typ, err := CreateSetType(test.vals, test.collation)
   278  			require.NoError(t, err)
   279  			concreteType, ok := typ.(SetType)
   280  			require.True(t, ok)
   281  			assert.True(t, test.collation.Equals(typ.Collation()))
   282  			str, err := concreteType.convertBitFieldToString(test.bit)
   283  			if assert.NoError(t, err) {
   284  				assert.Equal(t, test.expectedStr, str)
   285  			}
   286  		})
   287  	}
   288  }