github.com/dolthub/go-mysql-server@v0.18.0/sql/types/conversion.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package types 16 17 import ( 18 "fmt" 19 "strconv" 20 "strings" 21 "time" 22 23 "github.com/dolthub/vitess/go/sqltypes" 24 "github.com/dolthub/vitess/go/vt/proto/query" 25 "github.com/dolthub/vitess/go/vt/sqlparser" 26 "github.com/shopspring/decimal" 27 "gopkg.in/src-d/go-errors.v1" 28 29 "github.com/dolthub/go-mysql-server/sql" 30 ) 31 32 // ApproximateTypeFromValue returns the closest matching type to the given value. For example, an int16 will return SMALLINT. 33 func ApproximateTypeFromValue(val interface{}) sql.Type { 34 switch v := val.(type) { 35 case bool: 36 return Boolean 37 case int: 38 if strconv.IntSize == 32 { 39 return Int32 40 } 41 return Int64 42 case int64: 43 return Int64 44 case int32: 45 return Int32 46 case int16: 47 return Int16 48 case int8: 49 return Int8 50 case uint: 51 if strconv.IntSize == 32 { 52 return Uint32 53 } 54 return Uint64 55 case uint64: 56 return Uint64 57 case uint32: 58 return Uint32 59 case uint16: 60 return Uint16 61 case uint8: 62 return Uint8 63 case Timespan, time.Duration: 64 return Time 65 case time.Time: 66 return DatetimeMaxPrecision 67 case float32: 68 return Float32 69 case float64: 70 return Float64 71 case string: 72 typ, err := CreateString(sqltypes.VarChar, int64(len(v)), sql.Collation_Default) 73 if err != nil { 74 typ, err = CreateString(sqltypes.Text, int64(len(v)), sql.Collation_Default) 75 if err != nil { 76 typ = LongText 77 } 78 } 79 return typ 80 case []byte: 81 typ, err := CreateBinary(sqltypes.VarBinary, int64(len(v))) 82 if err != nil { 83 typ, err = CreateBinary(sqltypes.Blob, int64(len(v))) 84 if err != nil { 85 typ = LongBlob 86 } 87 } 88 return typ 89 case decimal.Decimal: 90 str := v.String() 91 dotIdx := strings.Index(str, ".") 92 if len(str) > 66 { 93 return Float64 94 } else if dotIdx == -1 { 95 typ, err := CreateDecimalType(uint8(len(str)), 0) 96 if err != nil { 97 return Float64 98 } 99 return typ 100 } else { 101 precision := uint8(len(str) - 1) 102 scale := uint8(len(str) - dotIdx - 1) 103 typ, err := CreateDecimalType(precision, scale) 104 if err != nil { 105 return Float64 106 } 107 return typ 108 } 109 case decimal.NullDecimal: 110 if !v.Valid { 111 return Float64 112 } 113 return ApproximateTypeFromValue(v.Decimal) 114 case nil: 115 return Null 116 default: 117 return LongText 118 } 119 } 120 121 // IsBinary returns whether the type represents binary data. 122 func IsBinary(sqlType query.Type) bool { 123 switch sqlType { 124 case sqltypes.Binary, 125 sqltypes.VarBinary, 126 sqltypes.Blob, 127 sqltypes.TypeJSON, 128 sqltypes.Geometry: 129 return true 130 } 131 return false 132 } 133 134 func allowsCharSet(sqlType query.Type) bool { 135 switch sqlType { 136 case sqltypes.VarChar, 137 sqltypes.Char, 138 sqltypes.Text, 139 sqltypes.Enum, 140 sqltypes.Set: 141 return true 142 } 143 return false 144 } 145 146 var ErrCharacterSetOnInvalidType = errors.NewKind("Only character columns, enums, and sets can have a CHARACTER SET option") 147 148 // ColumnTypeToType gets the column type using the column definition. 149 func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) { 150 if resolvedType, ok := ct.ResolvedType.(sql.Type); ok { 151 return resolvedType, nil 152 } 153 sqlType := ct.SQLType() 154 155 if !allowsCharSet(sqlType) && len(ct.Charset) != 0 { 156 return nil, ErrCharacterSetOnInvalidType.New() 157 } 158 159 collate := ct.Collate 160 if IsBinary(sqlType) && collate == "" { 161 collate = sql.Collation_binary.Name() 162 } 163 164 switch strings.ToLower(ct.Type) { 165 case "boolean", "bool": 166 return Boolean, nil 167 case "tinyint": 168 if ct.Length != nil { 169 displayWidth, err := strconv.Atoi(string(ct.Length.Val)) 170 if err != nil { 171 return nil, fmt.Errorf("unable to parse display width value: %w", err) 172 } 173 174 // As of MySQL 8.1.0, TINYINT is the only integer type for which MySQL will retain a display width, 175 // and ONLY if it's 1. All other types and display width values are dropped. TINYINT(1) seems to be 176 // left for backwards compatibility with ORM tools like ActiveRecord that rely on it for mapping to 177 // a boolean type. 178 if !ct.Unsigned && displayWidth == 1 { 179 return Boolean, nil 180 } 181 } 182 183 if ct.Unsigned { 184 return Uint8, nil 185 } 186 return Int8, nil 187 case "smallint": 188 if ct.Unsigned { 189 return Uint16, nil 190 } 191 return Int16, nil 192 case "mediumint": 193 if ct.Unsigned { 194 return Uint24, nil 195 } 196 return Int24, nil 197 case "int", "integer": 198 if ct.Unsigned { 199 return Uint32, nil 200 } 201 return Int32, nil 202 case "bigint": 203 if ct.Unsigned { 204 return Uint64, nil 205 } 206 return Int64, nil 207 case "float": 208 if ct.Length != nil { 209 precision, err := strconv.ParseInt(string(ct.Length.Val), 10, 8) 210 if err != nil { 211 return nil, err 212 } 213 if precision > 53 || precision < 0 { 214 return nil, sql.ErrInvalidColTypeDefinition.New(ct.String(), "Valid range for precision is 0-24 or 25-53") 215 } else if precision > 24 { 216 return Float64, nil 217 } else { 218 return Float32, nil 219 } 220 } 221 return Float32, nil 222 case "double", "real", "double precision": 223 return Float64, nil 224 case "decimal", "fixed", "dec", "numeric": 225 precision := int64(0) 226 scale := int64(0) 227 if ct.Length != nil { 228 var err error 229 precision, err = strconv.ParseInt(string(ct.Length.Val), 10, 8) 230 if err != nil { 231 return nil, err 232 } 233 } 234 if ct.Scale != nil { 235 var err error 236 scale, err = strconv.ParseInt(string(ct.Scale.Val), 10, 8) 237 if err != nil { 238 return nil, err 239 } 240 } 241 return CreateColumnDecimalType(uint8(precision), uint8(scale)) 242 case "bit": 243 length := int64(1) 244 if ct.Length != nil { 245 var err error 246 length, err = strconv.ParseInt(string(ct.Length.Val), 10, 8) 247 if err != nil { 248 return nil, err 249 } 250 } 251 return CreateBitType(uint8(length)) 252 case "tinytext", "tinyblob": 253 collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) 254 if err != nil { 255 return nil, err 256 } 257 return CreateString(sqltypes.Text, TinyTextBlobMax/collation.CharacterSet().MaxLength(), collation) 258 case "text", "blob": 259 collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) 260 if err != nil { 261 return nil, err 262 } 263 if ct.Length == nil { 264 return CreateString(sqltypes.Text, TextBlobMax/collation.CharacterSet().MaxLength(), collation) 265 } 266 length, err := strconv.ParseInt(string(ct.Length.Val), 10, 64) 267 if err != nil { 268 return nil, err 269 } 270 return CreateString(sqltypes.Text, length, collation) 271 case "mediumtext", "mediumblob", "long", "long varchar": 272 collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) 273 if err != nil { 274 return nil, err 275 } 276 return CreateString(sqltypes.Text, MediumTextBlobMax/collation.CharacterSet().MaxLength(), collation) 277 case "longtext", "longblob": 278 collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) 279 if err != nil { 280 return nil, err 281 } 282 return CreateString(sqltypes.Text, LongTextBlobMax/collation.CharacterSet().MaxLength(), collation) 283 case "char", "character", "binary": 284 collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) 285 if err != nil { 286 return nil, err 287 } 288 length := int64(1) 289 if ct.Length != nil { 290 var err error 291 length, err = strconv.ParseInt(string(ct.Length.Val), 10, 64) 292 if err != nil { 293 return nil, err 294 } 295 } 296 return CreateString(sqltypes.Char, length, collation) 297 case "nchar", "national char", "national character": 298 length := int64(1) 299 if ct.Length != nil { 300 var err error 301 length, err = strconv.ParseInt(string(ct.Length.Val), 10, 64) 302 if err != nil { 303 return nil, err 304 } 305 } 306 return CreateString(sqltypes.Char, length, sql.Collation_utf8mb3_general_ci) 307 case "varchar", "char varying", "character varying": 308 collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) 309 if err != nil { 310 return nil, err 311 } 312 if ct.Length == nil { 313 return nil, fmt.Errorf("VARCHAR requires a length") 314 } 315 316 var strLen = string(ct.Length.Val) 317 var length int64 318 if strings.ToLower(strLen) == "max" { 319 length = 16383 320 } else { 321 length, err = strconv.ParseInt(strLen, 10, 64) 322 if err != nil { 323 return nil, err 324 } 325 } 326 return CreateString(sqltypes.VarChar, length, collation) 327 case "nchar varchar", "nchar varying", "nvarchar", "national varchar", "national char varying", "national character varying": 328 if ct.Length == nil { 329 return nil, fmt.Errorf("VARCHAR requires a length") 330 } 331 length, err := strconv.ParseInt(string(ct.Length.Val), 10, 64) 332 if err != nil { 333 return nil, err 334 } 335 return CreateString(sqltypes.VarChar, length, sql.Collation_utf8mb3_general_ci) 336 case "varbinary": 337 collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) 338 if err != nil { 339 return nil, err 340 } 341 if ct.Length == nil { 342 return nil, fmt.Errorf("VARBINARY requires a length") 343 } 344 length, err := strconv.ParseInt(string(ct.Length.Val), 10, 64) 345 if err != nil { 346 return nil, err 347 } 348 // we need to have a separate check for varbinary, as CreateString checks varbinary against json limit 349 if length > varcharVarbinaryMax { 350 return nil, ErrLengthTooLarge.New(length, varcharVarbinaryMax) 351 } 352 return CreateString(sqltypes.VarBinary, length, collation) 353 case "year": 354 return Year, nil 355 case "date": 356 return CreateDatetimeType(sqltypes.Date, 0) 357 case "time": 358 if ct.Length != nil { 359 length, err := strconv.ParseInt(string(ct.Length.Val), 10, 64) 360 if err != nil { 361 return nil, err 362 } 363 switch length { 364 case 0, 1, 2, 3, 4, 5: 365 return nil, fmt.Errorf("TIME length not yet supported") 366 case 6: 367 return Time, nil 368 default: 369 return nil, fmt.Errorf("TIME only supports a length from 0 to 6") 370 } 371 } 372 return Time, nil 373 case "timestamp": 374 precision := int64(0) 375 if ct.Length != nil { 376 var err error 377 precision, err = strconv.ParseInt(string(ct.Length.Val), 10, 64) 378 if err != nil { 379 return nil, err 380 } 381 382 if precision > 6 || precision < 0 { 383 return nil, fmt.Errorf("TIMESTAMP supports precision from 0 to 6") 384 } 385 } 386 387 return CreateDatetimeType(sqltypes.Timestamp, int(precision)) 388 case "datetime": 389 precision := int64(0) 390 if ct.Length != nil { 391 var err error 392 precision, err = strconv.ParseInt(string(ct.Length.Val), 10, 64) 393 if err != nil { 394 return nil, err 395 } 396 397 if precision > 6 || precision < 0 { 398 return nil, fmt.Errorf("DATETIME supports precision from 0 to 6") 399 } 400 } 401 402 return CreateDatetimeType(sqltypes.Datetime, int(precision)) 403 case "enum": 404 collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) 405 if err != nil { 406 return nil, err 407 } 408 if collation.Sorter() == nil { 409 return nil, sql.ErrCollationNotYetImplementedTemp.New(collation.Name()) 410 } 411 return CreateEnumType(ct.EnumValues, collation) 412 case "set": 413 collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) 414 if err != nil { 415 return nil, err 416 } 417 if collation.Sorter() == nil { 418 return nil, sql.ErrCollationNotYetImplementedTemp.New(collation.Name()) 419 } 420 return CreateSetType(ct.EnumValues, collation) 421 case "json": 422 return JSON, nil 423 case "geometry": 424 return GeometryType{}, nil 425 case "geometrycollection": 426 return GeomCollType{}, nil 427 case "linestring": 428 return LineStringType{}, nil 429 case "multilinestring": 430 return MultiLineStringType{}, nil 431 case "point": 432 return PointType{}, nil 433 case "multipoint": 434 return MultiPointType{}, nil 435 case "polygon": 436 return PolygonType{}, nil 437 case "multipolygon": 438 return MultiPolygonType{}, nil 439 default: 440 return nil, fmt.Errorf("unknown type: %v", ct.Type) 441 } 442 return nil, fmt.Errorf("type not yet implemented: %v", ct.Type) 443 } 444 445 // CompareNulls compares two values, and returns true if either is null. 446 // The returned integer represents the ordering, with a rule that states nulls 447 // as being ordered before non-nulls. 448 func CompareNulls(a interface{}, b interface{}) (bool, int) { 449 aIsNull := a == nil 450 bIsNull := b == nil 451 if aIsNull && bIsNull { 452 return true, 0 453 } else if aIsNull && !bIsNull { 454 return true, 1 455 } else if !aIsNull && bIsNull { 456 return true, -1 457 } 458 return false, 0 459 } 460 461 // NumColumns returns the number of columns in a type. This is one for all 462 // types, except tuples. 463 func NumColumns(t sql.Type) int { 464 v, ok := t.(TupleType) 465 if !ok { 466 return 1 467 } 468 return len(v) 469 } 470 471 // ErrIfMismatchedColumns returns an operand error if the number of columns in 472 // t1 is not equal to the number of columns in t2. If the number of columns is 473 // equal, and both types are tuple types, it recurses into each subtype, 474 // asserting that those subtypes are structurally identical as well. 475 func ErrIfMismatchedColumns(t1, t2 sql.Type) error { 476 if NumColumns(t1) != NumColumns(t2) { 477 return sql.ErrInvalidOperandColumns.New(NumColumns(t1), NumColumns(t2)) 478 } 479 v1, ok1 := t1.(TupleType) 480 v2, ok2 := t2.(TupleType) 481 if ok1 && ok2 { 482 for i := range v1 { 483 if err := ErrIfMismatchedColumns(v1[i], v2[i]); err != nil { 484 return err 485 } 486 } 487 } 488 return nil 489 } 490 491 // ErrIfMismatchedColumnsInTuple returns an operand error is t2 is not a tuple 492 // type whose subtypes are structurally identical to t1. 493 func ErrIfMismatchedColumnsInTuple(t1, t2 sql.Type) error { 494 v2, ok2 := t2.(TupleType) 495 if !ok2 { 496 return sql.ErrInvalidOperandColumns.New(NumColumns(t1), NumColumns(t2)) 497 } 498 for _, v := range v2 { 499 if err := ErrIfMismatchedColumns(t1, v); err != nil { 500 return err 501 } 502 } 503 return nil 504 } 505 506 // TypesEqual compares two Types and returns whether they are equivalent. 507 func TypesEqual(a, b sql.Type) bool { 508 // TODO: replace all of the Type() == Type() calls with TypesEqual 509 510 // We can assume they have the same implementing type if this passes, so we have to check the parameters 511 if a == nil || b == nil || a.Type() != b.Type() { 512 return false 513 } 514 // Some types cannot be compared structurally as they contain non-comparable types (such as slices), so we handle 515 // those separately. 516 switch at := a.(type) { 517 case EnumType: 518 aEnumType := at 519 bEnumType := b.(EnumType) 520 if len(aEnumType.indexToVal) != len(bEnumType.indexToVal) { 521 return false 522 } 523 for i := 0; i < len(aEnumType.indexToVal); i++ { 524 if aEnumType.indexToVal[i] != bEnumType.indexToVal[i] { 525 return false 526 } 527 } 528 return aEnumType.collation == bEnumType.collation 529 case SetType: 530 aSetType := at 531 bSetType := b.(SetType) 532 if len(aSetType.bitToVal) != len(bSetType.bitToVal) { 533 return false 534 } 535 for bit, aVal := range aSetType.bitToVal { 536 if bVal, ok := bSetType.bitToVal[bit]; ok && aVal != bVal { 537 return false 538 } 539 } 540 return aSetType.collation == bSetType.collation 541 case TupleType: 542 if tupA, ok := a.(TupleType); ok { 543 if tupB, ok := b.(TupleType); ok && len(tupA) == len(tupB) { 544 for i := range tupA { 545 if !TypesEqual(tupA[i], tupB[i]) { 546 return false 547 } 548 } 549 return true 550 } 551 } 552 return false 553 default: 554 return a == b 555 } 556 }