github.com/XiaoMi/Gaea@v1.2.5/parser/tidb-types/convert.go (about) 1 // Copyright 2014 The ql Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSES/QL-LICENSE file. 4 5 // Copyright 2015 PingCAP, Inc. 6 // 7 // Licensed under the Apache License, Version 2.0 (the "License"); 8 // you may not use this file except in compliance with the License. 9 // You may obtain a copy of the License at 10 // 11 // http://www.apache.org/licenses/LICENSE-2.0 12 // 13 // Unless required by applicable law or agreed to in writing, software 14 // distributed under the License is distributed on an "AS IS" BASIS, 15 // See the License for the specific language governing permissions and 16 // limitations under the License. 17 18 package types 19 20 import ( 21 "math" 22 "strconv" 23 "strings" 24 25 "github.com/pingcap/errors" 26 27 "github.com/XiaoMi/Gaea/mysql" 28 "github.com/XiaoMi/Gaea/parser/stmtctx" 29 "github.com/XiaoMi/Gaea/parser/terror" 30 "github.com/XiaoMi/Gaea/parser/tidb-types/json" 31 "github.com/XiaoMi/Gaea/util/hack" 32 ) 33 34 func truncateStr(str string, flen int) string { 35 if flen != UnspecifiedLength && len(str) > flen { 36 str = str[:flen] 37 } 38 return str 39 } 40 41 // UnsignedUpperBound indicates the max uint64 values of different mysql types. 42 var UnsignedUpperBound = map[byte]uint64{ 43 mysql.TypeTiny: math.MaxUint8, 44 mysql.TypeShort: math.MaxUint16, 45 mysql.TypeInt24: mysql.MaxUint24, 46 mysql.TypeLong: math.MaxUint32, 47 mysql.TypeLonglong: math.MaxUint64, 48 mysql.TypeBit: math.MaxUint64, 49 mysql.TypeEnum: math.MaxUint64, 50 mysql.TypeSet: math.MaxUint64, 51 } 52 53 // SignedUpperBound indicates the max int64 values of different mysql types. 54 var SignedUpperBound = map[byte]int64{ 55 mysql.TypeTiny: math.MaxInt8, 56 mysql.TypeShort: math.MaxInt16, 57 mysql.TypeInt24: mysql.MaxInt24, 58 mysql.TypeLong: math.MaxInt32, 59 mysql.TypeLonglong: math.MaxInt64, 60 } 61 62 // SignedLowerBound indicates the min int64 values of different mysql types. 63 var SignedLowerBound = map[byte]int64{ 64 mysql.TypeTiny: math.MinInt8, 65 mysql.TypeShort: math.MinInt16, 66 mysql.TypeInt24: mysql.MinInt24, 67 mysql.TypeLong: math.MinInt32, 68 mysql.TypeLonglong: math.MinInt64, 69 } 70 71 // ConvertFloatToInt converts a float64 value to a int value. 72 func ConvertFloatToInt(fval float64, lowerBound, upperBound int64, tp byte) (int64, error) { 73 val := RoundFloat(fval) 74 if val < float64(lowerBound) { 75 return lowerBound, overflow(val, tp) 76 } 77 78 if val >= float64(upperBound) { 79 if val == float64(upperBound) { 80 return upperBound, nil 81 } 82 return upperBound, overflow(val, tp) 83 } 84 return int64(val), nil 85 } 86 87 // ConvertIntToInt converts an int value to another int value of different precision. 88 func ConvertIntToInt(val int64, lowerBound int64, upperBound int64, tp byte) (int64, error) { 89 if val < lowerBound { 90 return lowerBound, overflow(val, tp) 91 } 92 93 if val > upperBound { 94 return upperBound, overflow(val, tp) 95 } 96 97 return val, nil 98 } 99 100 // ConvertUintToInt converts an uint value to an int value. 101 func ConvertUintToInt(val uint64, upperBound int64, tp byte) (int64, error) { 102 if val > uint64(upperBound) { 103 return upperBound, overflow(val, tp) 104 } 105 106 return int64(val), nil 107 } 108 109 // ConvertIntToUint converts an int value to an uint value. 110 func ConvertIntToUint(sc *stmtctx.StatementContext, val int64, upperBound uint64, tp byte) (uint64, error) { 111 if sc.ShouldClipToZero() && val < 0 { 112 return 0, overflow(val, tp) 113 } 114 115 if uint64(val) > upperBound { 116 return upperBound, overflow(val, tp) 117 } 118 119 return uint64(val), nil 120 } 121 122 // ConvertUintToUint converts an uint value to another uint value of different precision. 123 func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) { 124 if val > upperBound { 125 return upperBound, overflow(val, tp) 126 } 127 128 return val, nil 129 } 130 131 // ConvertFloatToUint converts a float value to an uint value. 132 func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) { 133 val := RoundFloat(fval) 134 if val < 0 { 135 if sc.ShouldClipToZero() { 136 return 0, overflow(val, tp) 137 } 138 return uint64(int64(val)), overflow(val, tp) 139 } 140 141 if val > float64(upperBound) { 142 return upperBound, overflow(val, tp) 143 } 144 return uint64(val), nil 145 } 146 147 // StrToInt converts a string to an integer at the best-effort. 148 func StrToInt(sc *stmtctx.StatementContext, str string) (int64, error) { 149 str = strings.TrimSpace(str) 150 validPrefix, err := getValidIntPrefix(sc, str) 151 iVal, err1 := strconv.ParseInt(validPrefix, 10, 64) 152 if err1 != nil { 153 return iVal, ErrOverflow.GenWithStackByArgs("BIGINT", validPrefix) 154 } 155 return iVal, errors.Trace(err) 156 } 157 158 // StrToUint converts a string to an unsigned integer at the best-effortt. 159 func StrToUint(sc *stmtctx.StatementContext, str string) (uint64, error) { 160 str = strings.TrimSpace(str) 161 validPrefix, err := getValidIntPrefix(sc, str) 162 if validPrefix[0] == '+' { 163 validPrefix = validPrefix[1:] 164 } 165 uVal, err1 := strconv.ParseUint(validPrefix, 10, 64) 166 if err1 != nil { 167 return uVal, ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", validPrefix) 168 } 169 return uVal, errors.Trace(err) 170 } 171 172 // StrToDateTime converts str to MySQL DateTime. 173 func StrToDateTime(sc *stmtctx.StatementContext, str string, fsp int) (Time, error) { 174 return ParseTime(sc, str, mysql.TypeDatetime, fsp) 175 } 176 177 // StrToDuration converts str to Duration. It returns Duration in normal case, 178 // and returns Time when str is in datetime format. 179 // when isDuration is true, the d is returned, when it is false, the t is returned. 180 // See https://dev.mysql.com/doc/refman/5.5/en/date-and-time-literals.html. 181 func StrToDuration(sc *stmtctx.StatementContext, str string, fsp int) (d Duration, t Time, isDuration bool, err error) { 182 str = strings.TrimSpace(str) 183 length := len(str) 184 if length > 0 && str[0] == '-' { 185 length-- 186 } 187 // Timestamp format is 'YYYYMMDDHHMMSS' or 'YYMMDDHHMMSS', which length is 12. 188 // See #3923, it explains what we do here. 189 if length >= 12 { 190 t, err = StrToDateTime(sc, str, fsp) 191 if err == nil { 192 return d, t, false, nil 193 } 194 } 195 196 d, err = ParseDuration(sc, str, fsp) 197 if ErrTruncatedWrongVal.Equal(err) { 198 err = sc.HandleTruncate(err) 199 } 200 return d, t, true, errors.Trace(err) 201 } 202 203 // NumberToDuration converts number to Duration. 204 func NumberToDuration(number int64, fsp int) (Duration, error) { 205 if number > TimeMaxValue { 206 // Try to parse DATETIME. 207 if number >= 10000000000 { // '2001-00-00 00-00-00' 208 if t, err := ParseDatetimeFromNum(nil, number); err == nil { 209 dur, err1 := t.ConvertToDuration() 210 return dur, errors.Trace(err1) 211 } 212 } 213 dur, err1 := MaxMySQLTime(fsp).ConvertToDuration() 214 terror.Log(err1) 215 return dur, ErrOverflow.GenWithStackByArgs("Duration", strconv.Itoa(int(number))) 216 } else if number < -TimeMaxValue { 217 dur, err1 := MaxMySQLTime(fsp).ConvertToDuration() 218 terror.Log(err1) 219 dur.Duration = -dur.Duration 220 return dur, ErrOverflow.GenWithStackByArgs("Duration", strconv.Itoa(int(number))) 221 } 222 var neg bool 223 if neg = number < 0; neg { 224 number = -number 225 } 226 227 if number/10000 > TimeMaxHour || number%100 >= 60 || (number/100)%100 >= 60 { 228 return ZeroDuration, errors.Trace(ErrInvalidTimeFormat.GenWithStackByArgs(number)) 229 } 230 t := Time{Time: FromDate(0, 0, 0, int(number/10000), int((number/100)%100), int(number%100), 0), Type: mysql.TypeDuration, Fsp: fsp} 231 dur, err := t.ConvertToDuration() 232 if err != nil { 233 return ZeroDuration, errors.Trace(err) 234 } 235 if neg { 236 dur.Duration = -dur.Duration 237 } 238 return dur, nil 239 } 240 241 // getValidIntPrefix gets prefix of the string which can be successfully parsed as int. 242 func getValidIntPrefix(sc *stmtctx.StatementContext, str string) (string, error) { 243 floatPrefix, err := getValidFloatPrefix(sc, str) 244 if err != nil { 245 return floatPrefix, errors.Trace(err) 246 } 247 return floatStrToIntStr(sc, floatPrefix, str) 248 } 249 250 // roundIntStr is to round int string base on the number following dot. 251 func roundIntStr(numNextDot byte, intStr string) string { 252 if numNextDot < '5' { 253 return intStr 254 } 255 retStr := []byte(intStr) 256 for i := len(intStr) - 1; i >= 0; i-- { 257 if retStr[i] != '9' { 258 retStr[i]++ 259 break 260 } 261 if i == 0 { 262 retStr[i] = '1' 263 retStr = append(retStr, '0') 264 break 265 } 266 retStr[i] = '0' 267 } 268 return string(retStr) 269 } 270 271 // floatStrToIntStr converts a valid float string into valid integer string which can be parsed by 272 // strconv.ParseInt, we can't parse float first then convert it to string because precision will 273 // be lost. For example, the string value "18446744073709551615" which is the max number of unsigned 274 // int will cause some precision to lose. intStr[0] may be a positive and negative sign like '+' or '-'. 275 func floatStrToIntStr(sc *stmtctx.StatementContext, validFloat string, oriStr string) (intStr string, _ error) { 276 var dotIdx = -1 277 var eIdx = -1 278 for i := 0; i < len(validFloat); i++ { 279 switch validFloat[i] { 280 case '.': 281 dotIdx = i 282 case 'e', 'E': 283 eIdx = i 284 } 285 } 286 if eIdx == -1 { 287 if dotIdx == -1 { 288 return validFloat, nil 289 } 290 var digits []byte 291 if validFloat[0] == '-' || validFloat[0] == '+' { 292 dotIdx-- 293 digits = []byte(validFloat[1:]) 294 } else { 295 digits = []byte(validFloat) 296 } 297 if dotIdx == 0 { 298 intStr = "0" 299 } else { 300 intStr = string(digits)[:dotIdx] 301 } 302 if len(digits) > dotIdx+1 { 303 intStr = roundIntStr(digits[dotIdx+1], intStr) 304 } 305 if (len(intStr) > 1 || intStr[0] != '0') && validFloat[0] == '-' { 306 intStr = "-" + intStr 307 } 308 return intStr, nil 309 } 310 var intCnt int 311 digits := make([]byte, 0, len(validFloat)) 312 if dotIdx == -1 { 313 digits = append(digits, validFloat[:eIdx]...) 314 intCnt = len(digits) 315 } else { 316 digits = append(digits, validFloat[:dotIdx]...) 317 intCnt = len(digits) 318 digits = append(digits, validFloat[dotIdx+1:eIdx]...) 319 } 320 exp, err := strconv.Atoi(validFloat[eIdx+1:]) 321 if err != nil { 322 return validFloat, errors.Trace(err) 323 } 324 if exp > 0 && int64(intCnt) > (math.MaxInt64-int64(exp)) { 325 // (exp + incCnt) overflows MaxInt64. 326 sc.AppendWarning(ErrOverflow.GenWithStackByArgs("BIGINT", oriStr)) 327 return validFloat[:eIdx], nil 328 } 329 intCnt += exp 330 if intCnt <= 0 { 331 intStr = "0" 332 if intCnt == 0 && len(digits) > 0 { 333 intStr = roundIntStr(digits[0], intStr) 334 } 335 return intStr, nil 336 } 337 if intCnt == 1 && (digits[0] == '-' || digits[0] == '+') { 338 intStr = "0" 339 if len(digits) > 1 { 340 intStr = roundIntStr(digits[1], intStr) 341 } 342 if intStr[0] == '1' { 343 intStr = string(digits[:1]) + intStr 344 } 345 return intStr, nil 346 } 347 if intCnt <= len(digits) { 348 intStr = string(digits[:intCnt]) 349 if intCnt < len(digits) { 350 intStr = roundIntStr(digits[intCnt], intStr) 351 } 352 } else { 353 // convert scientific notation decimal number 354 extraZeroCount := intCnt - len(digits) 355 if extraZeroCount > 20 { 356 // Append overflow warning and return to avoid allocating too much memory. 357 sc.AppendWarning(ErrOverflow.GenWithStackByArgs("BIGINT", oriStr)) 358 return validFloat[:eIdx], nil 359 } 360 intStr = string(digits) + strings.Repeat("0", extraZeroCount) 361 } 362 return intStr, nil 363 } 364 365 // StrToFloat converts a string to a float64 at the best-effort. 366 func StrToFloat(sc *stmtctx.StatementContext, str string) (float64, error) { 367 str = strings.TrimSpace(str) 368 validStr, err := getValidFloatPrefix(sc, str) 369 f, err1 := strconv.ParseFloat(validStr, 64) 370 if err1 != nil { 371 if err2, ok := err1.(*strconv.NumError); ok { 372 // value will truncate to MAX/MIN if out of range. 373 if err2.Err == strconv.ErrRange { 374 err1 = sc.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("DOUBLE", str)) 375 if math.IsInf(f, 1) { 376 f = math.MaxFloat64 377 } else if math.IsInf(f, -1) { 378 f = -math.MaxFloat64 379 } 380 } 381 } 382 return f, errors.Trace(err1) 383 } 384 return f, errors.Trace(err) 385 } 386 387 // ConvertJSONToInt casts JSON into int64. 388 func ConvertJSONToInt(sc *stmtctx.StatementContext, j json.BinaryJSON, unsigned bool) (int64, error) { 389 switch j.TypeCode { 390 case json.TypeCodeObject, json.TypeCodeArray: 391 return 0, nil 392 case json.TypeCodeLiteral: 393 switch j.Value[0] { 394 case json.LiteralNil, json.LiteralFalse: 395 return 0, nil 396 default: 397 return 1, nil 398 } 399 case json.TypeCodeInt64, json.TypeCodeUint64: 400 return j.GetInt64(), nil 401 case json.TypeCodeFloat64: 402 f := j.GetFloat64() 403 if !unsigned { 404 lBound := SignedLowerBound[mysql.TypeLonglong] 405 uBound := SignedUpperBound[mysql.TypeLonglong] 406 return ConvertFloatToInt(f, lBound, uBound, mysql.TypeDouble) 407 } 408 bound := UnsignedUpperBound[mysql.TypeLonglong] 409 u, err := ConvertFloatToUint(sc, f, bound, mysql.TypeDouble) 410 return int64(u), errors.Trace(err) 411 case json.TypeCodeString: 412 str := string(hack.String(j.GetString())) 413 return StrToInt(sc, str) 414 } 415 return 0, errors.New("Unknown type code in JSON") 416 } 417 418 // ConvertJSONToFloat casts JSON into float64. 419 func ConvertJSONToFloat(sc *stmtctx.StatementContext, j json.BinaryJSON) (float64, error) { 420 switch j.TypeCode { 421 case json.TypeCodeObject, json.TypeCodeArray: 422 return 0, nil 423 case json.TypeCodeLiteral: 424 switch j.Value[0] { 425 case json.LiteralNil, json.LiteralFalse: 426 return 0, nil 427 default: 428 return 1, nil 429 } 430 case json.TypeCodeInt64: 431 return float64(j.GetInt64()), nil 432 case json.TypeCodeUint64: 433 u, err := ConvertIntToUint(sc, j.GetInt64(), UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) 434 return float64(u), errors.Trace(err) 435 case json.TypeCodeFloat64: 436 return j.GetFloat64(), nil 437 case json.TypeCodeString: 438 str := string(hack.String(j.GetString())) 439 return StrToFloat(sc, str) 440 } 441 return 0, errors.New("Unknown type code in JSON") 442 } 443 444 // ConvertJSONToDecimal casts JSON into decimal. 445 func ConvertJSONToDecimal(sc *stmtctx.StatementContext, j json.BinaryJSON) (*MyDecimal, error) { 446 res := new(MyDecimal) 447 if j.TypeCode != json.TypeCodeString { 448 f64, err := ConvertJSONToFloat(sc, j) 449 if err != nil { 450 return res, errors.Trace(err) 451 } 452 err = res.FromFloat64(f64) 453 return res, errors.Trace(err) 454 } 455 err := sc.HandleTruncate(res.FromString([]byte(j.GetString()))) 456 return res, errors.Trace(err) 457 } 458 459 // getValidFloatPrefix gets prefix of string which can be successfully parsed as float. 460 func getValidFloatPrefix(sc *stmtctx.StatementContext, s string) (valid string, err error) { 461 var ( 462 sawDot bool 463 sawDigit bool 464 validLen int 465 eIdx int 466 ) 467 for i := 0; i < len(s); i++ { 468 c := s[i] 469 if c == '+' || c == '-' { 470 if i != 0 && i != eIdx+1 { // "1e+1" is valid. 471 break 472 } 473 } else if c == '.' { 474 if sawDot || eIdx > 0 { // "1.1." or "1e1.1" 475 break 476 } 477 sawDot = true 478 if sawDigit { // "123." is valid. 479 validLen = i + 1 480 } 481 } else if c == 'e' || c == 'E' { 482 if !sawDigit { // "+.e" 483 break 484 } 485 if eIdx != 0 { // "1e5e" 486 break 487 } 488 eIdx = i 489 } else if c < '0' || c > '9' { 490 break 491 } else { 492 sawDigit = true 493 validLen = i + 1 494 } 495 } 496 valid = s[:validLen] 497 if valid == "" { 498 valid = "0" 499 } 500 if validLen == 0 || validLen != len(s) { 501 err = errors.Trace(handleTruncateError(sc)) 502 } 503 return valid, err 504 } 505 506 // ToString converts an interface to a string. 507 func ToString(value interface{}) (string, error) { 508 switch v := value.(type) { 509 case bool: 510 if v { 511 return "1", nil 512 } 513 return "0", nil 514 case int: 515 return strconv.FormatInt(int64(v), 10), nil 516 case int64: 517 return strconv.FormatInt(v, 10), nil 518 case uint64: 519 return strconv.FormatUint(v, 10), nil 520 case float32: 521 return strconv.FormatFloat(float64(v), 'f', -1, 32), nil 522 case float64: 523 return strconv.FormatFloat(v, 'f', -1, 64), nil 524 case string: 525 return v, nil 526 case []byte: 527 return string(v), nil 528 case Time: 529 return v.String(), nil 530 case Duration: 531 return v.String(), nil 532 case *MyDecimal: 533 return v.String(), nil 534 case BinaryLiteral: 535 return v.ToString(), nil 536 case Enum: 537 return v.String(), nil 538 case Set: 539 return v.String(), nil 540 default: 541 return "", errors.Errorf("cannot convert %v(type %T) to string", value, value) 542 } 543 }