github.com/dolthub/go-mysql-server@v0.18.0/sql/types/number_test.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package types
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  	"reflect"
    21  	"strconv"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/dolthub/vitess/go/sqltypes"
    26  	"github.com/dolthub/vitess/go/vt/proto/query"
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  
    30  	"github.com/dolthub/go-mysql-server/sql"
    31  )
    32  
    33  func TestNumberCompare(t *testing.T) {
    34  	tests := []struct {
    35  		typ         sql.Type
    36  		val1        interface{}
    37  		val2        interface{}
    38  		expectedCmp int
    39  	}{
    40  		{Int8, nil, 0, 1},
    41  		{Uint24, 0, nil, -1},
    42  		{Float64, nil, nil, 0},
    43  
    44  		{Boolean, 0, 1, -1},
    45  		{Int8, -1, 2, -1},
    46  		{Int16, -2, 3, -1},
    47  		{Int24, -3, 4, -1},
    48  		{Int32, -4, 5, -1},
    49  		{Int64, -5, 6, -1},
    50  		{Uint8, 6, 7, -1},
    51  		{Uint16, 7, 8, -1},
    52  		{Uint24, 8, 9, -1},
    53  		{Uint32, 9, 10, -1},
    54  		{Uint64, 10, 11, -1},
    55  		{Float32, -11.1, 12.2, -1},
    56  		{Float64, -12.2, 13.3, -1},
    57  		{Boolean, 0, 0, 0},
    58  		{Int8, 1, 1, 0},
    59  		{Int16, 2, 2, 0},
    60  		{Int24, 3, 3, 0},
    61  		{Int32, 4, 4, 0},
    62  		{Int64, 5, 5, 0},
    63  		{Uint8, 6, 6, 0},
    64  		{Uint16, 7, 7, 0},
    65  		{Uint24, 8, 8, 0},
    66  		{Uint32, 9, 9, 0},
    67  		{Uint64, 10, 10, 0},
    68  		{Float32, 11.1, 11.1, 0},
    69  		{Float64, 12.2, 12.2, 0},
    70  		{Boolean, 1, 0, 1},
    71  		{Int8, 2, -1, 1},
    72  		{Int16, 3, -2, 1},
    73  		{Int24, 4, -3, 1},
    74  		{Int32, 5, -4, 1},
    75  		{Int64, 6, -5, 1},
    76  		{Uint8, 7, 6, 1},
    77  		{Uint16, 8, 7, 1},
    78  		{Uint24, 9, 8, 1},
    79  		{Uint32, 10, 9, 1},
    80  		{Uint64, 11, 10, 1},
    81  		{Float32, 12.2, -11.1, 1},
    82  		{Float64, 13.3, -12.2, 1},
    83  	}
    84  
    85  	for _, test := range tests {
    86  		t.Run(fmt.Sprintf("%v %v %v", test.typ, test.val1, test.val2), func(t *testing.T) {
    87  			cmp, err := test.typ.Compare(test.val1, test.val2)
    88  			require.NoError(t, err)
    89  			assert.Equal(t, test.expectedCmp, cmp)
    90  		})
    91  	}
    92  }
    93  
    94  func TestNumberCreate(t *testing.T) {
    95  	tests := []struct {
    96  		baseType     query.Type
    97  		expectedType NumberTypeImpl_
    98  		expectedErr  bool
    99  	}{
   100  		{sqltypes.Int8, NumberTypeImpl_{sqltypes.Int8, 0}, false},
   101  		{sqltypes.Int16, NumberTypeImpl_{sqltypes.Int16, 0}, false},
   102  		{sqltypes.Int24, NumberTypeImpl_{sqltypes.Int24, 0}, false},
   103  		{sqltypes.Int32, NumberTypeImpl_{sqltypes.Int32, 0}, false},
   104  		{sqltypes.Int64, NumberTypeImpl_{sqltypes.Int64, 0}, false},
   105  		{sqltypes.Uint8, NumberTypeImpl_{sqltypes.Uint8, 0}, false},
   106  		{sqltypes.Uint16, NumberTypeImpl_{sqltypes.Uint16, 0}, false},
   107  		{sqltypes.Uint24, NumberTypeImpl_{sqltypes.Uint24, 0}, false},
   108  		{sqltypes.Uint32, NumberTypeImpl_{sqltypes.Uint32, 0}, false},
   109  		{sqltypes.Uint64, NumberTypeImpl_{sqltypes.Uint64, 0}, false},
   110  		{sqltypes.Float32, NumberTypeImpl_{sqltypes.Float32, 0}, false},
   111  		{sqltypes.Float64, NumberTypeImpl_{sqltypes.Float64, 0}, false},
   112  	}
   113  
   114  	for _, test := range tests {
   115  		t.Run(fmt.Sprintf("%v", test.baseType), func(t *testing.T) {
   116  			typ, err := CreateNumberType(test.baseType)
   117  			if test.expectedErr {
   118  				assert.Error(t, err)
   119  			} else {
   120  				require.NoError(t, err)
   121  				assert.Equal(t, test.expectedType, typ)
   122  			}
   123  		})
   124  	}
   125  }
   126  
   127  func TestNumberCreateInvalidBaseTypes(t *testing.T) {
   128  	tests := []struct {
   129  		baseType     query.Type
   130  		expectedType NumberTypeImpl_
   131  		expectedErr  bool
   132  	}{
   133  		{sqltypes.Binary, NumberTypeImpl_{}, true},
   134  		{sqltypes.Bit, NumberTypeImpl_{}, true},
   135  		{sqltypes.Blob, NumberTypeImpl_{}, true},
   136  		{sqltypes.Char, NumberTypeImpl_{}, true},
   137  		{sqltypes.Date, NumberTypeImpl_{}, true},
   138  		{sqltypes.Datetime, NumberTypeImpl_{}, true},
   139  		{sqltypes.Decimal, NumberTypeImpl_{}, true},
   140  		{sqltypes.Enum, NumberTypeImpl_{}, true},
   141  		{sqltypes.Expression, NumberTypeImpl_{}, true},
   142  		{sqltypes.Geometry, NumberTypeImpl_{}, true},
   143  		{sqltypes.Null, NumberTypeImpl_{}, true},
   144  		{sqltypes.Set, NumberTypeImpl_{}, true},
   145  		{sqltypes.Text, NumberTypeImpl_{}, true},
   146  		{sqltypes.Time, NumberTypeImpl_{}, true},
   147  		{sqltypes.Timestamp, NumberTypeImpl_{}, true},
   148  		{sqltypes.TypeJSON, NumberTypeImpl_{}, true},
   149  		{sqltypes.VarBinary, NumberTypeImpl_{}, true},
   150  		{sqltypes.VarChar, NumberTypeImpl_{}, true},
   151  		{sqltypes.Year, NumberTypeImpl_{}, true},
   152  	}
   153  
   154  	for _, test := range tests {
   155  		t.Run(fmt.Sprintf("%v", test.baseType), func(t *testing.T) {
   156  			typ, err := CreateNumberType(test.baseType)
   157  			if test.expectedErr {
   158  				assert.Error(t, err)
   159  			} else {
   160  				require.NoError(t, err)
   161  				assert.Equal(t, test.expectedType, typ)
   162  			}
   163  		})
   164  	}
   165  }
   166  
   167  func TestNumberConvert(t *testing.T) {
   168  	tests := []struct {
   169  		typ     sql.Type
   170  		inp     interface{}
   171  		exp     interface{}
   172  		err     bool
   173  		inRange sql.ConvertInRange
   174  	}{
   175  		{typ: Boolean, inp: true, exp: int8(1), err: false, inRange: sql.InRange},
   176  		{typ: Int8, inp: int32(0), exp: int8(0), err: false, inRange: sql.InRange},
   177  		{typ: Int16, inp: uint16(1), exp: int16(1), err: false, inRange: sql.InRange},
   178  		{typ: Int24, inp: false, exp: int32(0), err: false, inRange: sql.InRange},
   179  		{typ: Int32, inp: nil, exp: nil, err: false, inRange: sql.InRange},
   180  		{typ: Int64, inp: "33", exp: int64(33), err: false, inRange: sql.InRange},
   181  		{typ: Int64, inp: "33.0", exp: int64(33), err: false, inRange: sql.InRange},
   182  		{typ: Int64, inp: "33.1", exp: int64(33), err: false, inRange: sql.InRange},
   183  		{typ: Int64, inp: strconv.FormatInt(math.MaxInt64, 10), exp: int64(math.MaxInt64), err: false, inRange: sql.InRange},
   184  		{typ: Int64, inp: true, exp: int64(1), err: false, inRange: sql.InRange},
   185  		{typ: Int64, inp: false, exp: int64(0), err: false, inRange: sql.InRange},
   186  		{typ: Uint8, inp: int64(34), exp: uint8(34), err: false, inRange: sql.InRange},
   187  		{typ: Uint16, inp: int16(35), exp: uint16(35), err: false, inRange: sql.InRange},
   188  		{typ: Uint24, inp: 36.756, exp: uint32(37), err: false, inRange: sql.InRange},
   189  		{typ: Uint32, inp: uint8(37), exp: uint32(37), err: false, inRange: sql.InRange},
   190  		{typ: Uint64, inp: time.Date(2009, 1, 2, 3, 4, 5, 0, time.UTC), exp: uint64(time.Date(2009, 1, 2, 3, 4, 5, 0, time.UTC).Unix()), err: false, inRange: sql.InRange},
   191  		{typ: Uint64, inp: "01000", exp: uint64(1000), err: false, inRange: sql.InRange},
   192  		{typ: Uint64, inp: true, exp: uint64(1), err: false, inRange: sql.InRange},
   193  		{typ: Uint64, inp: false, exp: uint64(0), err: false, inRange: sql.InRange},
   194  		{typ: Float32, inp: "22.25", exp: float32(22.25), err: false, inRange: sql.InRange},
   195  		{typ: Float32, inp: []byte{90, 140, 228, 206, 116}, exp: float32(388910861940), err: false, inRange: sql.InRange},
   196  		{typ: Float64, inp: float32(893.875), exp: float64(893.875), err: false, inRange: sql.InRange},
   197  		{typ: Boolean, inp: math.MaxInt8 + 1, exp: int8(math.MaxInt8), err: false, inRange: sql.OutOfRange},
   198  		{typ: Int8, inp: math.MaxInt8 + 1, exp: int8(math.MaxInt8), err: false, inRange: sql.OutOfRange},
   199  		{typ: Int8, inp: math.MinInt8 - 1, exp: int8(math.MinInt8), err: false, inRange: sql.OutOfRange},
   200  		{typ: Int16, inp: math.MaxInt16 + 1, exp: int16(math.MaxInt16), err: false, inRange: sql.OutOfRange},
   201  		{typ: Int16, inp: math.MinInt16 - 1, exp: int16(math.MinInt16), err: false, inRange: sql.OutOfRange},
   202  		{typ: Int24, inp: 1 << 24, exp: int32(1<<23 - 1), err: false, inRange: sql.OutOfRange},
   203  		{typ: Int24, inp: -1 << 24, exp: int32(-1 << 23), err: false, inRange: sql.OutOfRange},
   204  		{typ: Int32, inp: math.MaxInt32 + 1, exp: int32(math.MaxInt32), err: false, inRange: sql.OutOfRange},
   205  		{typ: Int32, inp: math.MinInt32 - 1, exp: int32(math.MinInt32), err: false, inRange: sql.OutOfRange},
   206  		{typ: Int64, inp: uint64(math.MaxInt64 + 1), exp: int64(math.MaxInt64), err: false, inRange: sql.OutOfRange},
   207  		{typ: Uint8, inp: math.MaxUint8 + 1, exp: uint8(math.MaxUint8), err: false, inRange: sql.OutOfRange},
   208  		{typ: Uint8, inp: -1, exp: uint8(math.MaxUint8), err: false, inRange: sql.OutOfRange},
   209  		{typ: Uint16, inp: math.MaxUint16 + 1, exp: uint16(math.MaxUint16), err: false, inRange: sql.OutOfRange},
   210  		{typ: Uint16, inp: -1, exp: uint16(math.MaxUint16), err: false, inRange: sql.OutOfRange},
   211  		{typ: Uint24, inp: 1 << 24, exp: uint32(1<<24 - 1), err: false, inRange: sql.OutOfRange},
   212  		{typ: Uint24, inp: -1, exp: uint32(1<<24 - 1), err: false, inRange: sql.OutOfRange},
   213  		{typ: Uint32, inp: math.MaxUint32 + 1, exp: uint32(math.MaxUint32), err: false, inRange: sql.OutOfRange},
   214  		{typ: Uint32, inp: -1, exp: uint32(math.MaxUint32), err: false, inRange: sql.OutOfRange},
   215  		{typ: Uint64, inp: -1, exp: uint64(math.MaxUint64), err: false, inRange: sql.OutOfRange},
   216  		{typ: Float32, inp: math.MaxFloat32 * 2, exp: float32(math.MaxFloat32), err: false, inRange: sql.OutOfRange},
   217  	}
   218  
   219  	for _, test := range tests {
   220  		t.Run(fmt.Sprintf("%v %v %v", test.typ, test.inp, test.exp), func(t *testing.T) {
   221  			val, inRange, err := test.typ.Convert(test.inp)
   222  			if test.err {
   223  				assert.Error(t, err)
   224  			} else {
   225  				require.NoError(t, err)
   226  				assert.Equal(t, test.exp, val)
   227  				assert.Equal(t, test.inRange, inRange)
   228  				if val != nil {
   229  					assert.Equal(t, test.typ.ValueType(), reflect.TypeOf(val))
   230  				}
   231  			}
   232  		})
   233  	}
   234  }
   235  
   236  func TestNumberSQL_BooleanFromBoolean(t *testing.T) {
   237  	val, err := Boolean.SQL(sql.NewEmptyContext(), nil, true)
   238  	require.NoError(t, err)
   239  	assert.Equal(t, "INT8(1)", val.String())
   240  
   241  	val, err = Boolean.SQL(sql.NewEmptyContext(), nil, false)
   242  	require.NoError(t, err)
   243  	assert.Equal(t, "INT8(0)", val.String())
   244  }
   245  
   246  func TestNumberSQL_NumberFromString(t *testing.T) {
   247  	val, err := Int64.SQL(sql.NewEmptyContext(), nil, "not a number")
   248  	require.NoError(t, err)
   249  	assert.Equal(t, "not a number", val.ToString())
   250  
   251  	val, err = Float64.SQL(sql.NewEmptyContext(), nil, "also not a number")
   252  	require.NoError(t, err)
   253  	assert.Equal(t, "also not a number", val.ToString())
   254  }
   255  
   256  func TestNumberString(t *testing.T) {
   257  	tests := []struct {
   258  		typ         sql.Type
   259  		expectedStr string
   260  	}{
   261  		{Boolean, "tinyint(1)"},
   262  		{Int8, "tinyint"},
   263  		{Int16, "smallint"},
   264  		{Int24, "mediumint"},
   265  		{Int32, "int"},
   266  		{Int64, "bigint"},
   267  		{Uint8, "tinyint unsigned"},
   268  		{Uint16, "smallint unsigned"},
   269  		{Uint24, "mediumint unsigned"},
   270  		{Uint32, "int unsigned"},
   271  		{Uint64, "bigint unsigned"},
   272  		{Float32, "float"},
   273  		{Float64, "double"},
   274  	}
   275  
   276  	for _, test := range tests {
   277  		t.Run(fmt.Sprintf("%v %v", test.typ, test.expectedStr), func(t *testing.T) {
   278  			str := test.typ.String()
   279  			assert.Equal(t, test.expectedStr, str)
   280  		})
   281  	}
   282  }