github.com/dolthub/go-mysql-server@v0.18.0/sql/types/conversion_test.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package types
    16  
    17  import (
    18  	"fmt"
    19  	"testing"
    20  
    21  	"github.com/dolthub/vitess/go/sqltypes"
    22  	"github.com/dolthub/vitess/go/vt/sqlparser"
    23  	"github.com/stretchr/testify/assert"
    24  
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  )
    27  
    28  func TestFloatCovert(t *testing.T) {
    29  	tests := []struct {
    30  		length   string
    31  		scale    string
    32  		expected sql.Type
    33  		err      bool
    34  	}{
    35  		{"20", "2", Float32, false},
    36  		{"-1", "", nil, true},
    37  		{"54", "", nil, true},
    38  		{"", "", Float32, false},
    39  		{"0", "", Float32, false},
    40  		{"24", "", Float32, false},
    41  		{"25", "", Float64, false},
    42  		{"53", "", Float64, false},
    43  	}
    44  
    45  	for _, test := range tests {
    46  		t.Run(fmt.Sprintf("%v %v %v", test.length, test.scale, test.err), func(t *testing.T) {
    47  			var precision *sqlparser.SQLVal = nil
    48  			var scale *sqlparser.SQLVal = nil
    49  
    50  			if test.length != "" {
    51  				precision = &sqlparser.SQLVal{
    52  					Type: sqlparser.IntVal,
    53  					Val:  []byte(test.length),
    54  				}
    55  			}
    56  
    57  			if test.scale != "" {
    58  				scale = &sqlparser.SQLVal{
    59  					Type: sqlparser.IntVal,
    60  					Val:  []byte(test.scale),
    61  				}
    62  			}
    63  
    64  			ct := &sqlparser.ColumnType{
    65  				Type:   "FLOAT",
    66  				Scale:  scale,
    67  				Length: precision,
    68  			}
    69  			res, err := ColumnTypeToType(ct)
    70  			if test.err {
    71  				assert.Error(t, err)
    72  			} else {
    73  				assert.Equal(t, test.expected, res)
    74  			}
    75  		})
    76  	}
    77  }
    78  
    79  func TestColumnTypeToType_Time(t *testing.T) {
    80  	tests := []struct {
    81  		length   string
    82  		expected sql.Type
    83  		err      bool
    84  	}{
    85  		{"", Time, false},
    86  		{"0", nil, true},
    87  		{"1", nil, true},
    88  		{"2", nil, true},
    89  		{"3", nil, true},
    90  		{"4", nil, true},
    91  		{"5", nil, true},
    92  		{"6", Time, false},
    93  		{"7", nil, true},
    94  	}
    95  
    96  	for _, test := range tests {
    97  		t.Run(fmt.Sprintf("%v %v", test.length, test.err), func(t *testing.T) {
    98  			var precision *sqlparser.SQLVal
    99  
   100  			if test.length != "" {
   101  				precision = &sqlparser.SQLVal{
   102  					Type: sqlparser.IntVal,
   103  					Val:  []byte(test.length),
   104  				}
   105  			}
   106  
   107  			ct := &sqlparser.ColumnType{
   108  				Type:   "TIME",
   109  				Length: precision,
   110  			}
   111  			res, err := ColumnTypeToType(ct)
   112  			if test.err {
   113  				assert.Error(t, err)
   114  			} else {
   115  				assert.Equal(t, test.expected, res)
   116  			}
   117  		})
   118  	}
   119  }
   120  
   121  func TestColumnCharTypes(t *testing.T) {
   122  	test := []struct {
   123  		typ string
   124  		len int64
   125  		exp sql.Type
   126  	}{
   127  		{
   128  			typ: "nchar varchar",
   129  			len: 10,
   130  			exp: StringType{baseType: sqltypes.VarChar, maxCharLength: 10, maxByteLength: 30, collation: 33},
   131  		},
   132  		{
   133  			typ: "char varying",
   134  			len: 10,
   135  			exp: StringType{baseType: sqltypes.VarChar, maxCharLength: 10, maxByteLength: 40},
   136  		},
   137  		{
   138  			typ: "nchar varying",
   139  			len: 10,
   140  			exp: StringType{baseType: sqltypes.VarChar, maxCharLength: 10, maxByteLength: 30, collation: 33},
   141  		},
   142  		{
   143  			typ: "national char varying",
   144  			len: 10,
   145  			exp: StringType{baseType: sqltypes.VarChar, maxCharLength: 10, maxByteLength: 30, collation: 33},
   146  		},
   147  	}
   148  
   149  	for _, test := range test {
   150  		t.Run(fmt.Sprintf("%v %v", test.typ, test.exp), func(t *testing.T) {
   151  			ct := &sqlparser.ColumnType{
   152  				Type:   test.typ,
   153  				Length: &sqlparser.SQLVal{Type: sqlparser.IntVal, Val: []byte(fmt.Sprintf("%v", test.len))},
   154  			}
   155  			res, err := ColumnTypeToType(ct)
   156  			assert.NoError(t, err)
   157  			assert.Equal(t, test.exp, res)
   158  		})
   159  	}
   160  }