github.com/XiaoMi/Gaea@v1.2.5/parser/tidb-types/convert.go (about)

     1  // Copyright 2014 The ql Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSES/QL-LICENSE file.
     4  
     5  // Copyright 2015 PingCAP, Inc.
     6  //
     7  // Licensed under the Apache License, Version 2.0 (the "License");
     8  // you may not use this file except in compliance with the License.
     9  // You may obtain a copy of the License at
    10  //
    11  //     http://www.apache.org/licenses/LICENSE-2.0
    12  //
    13  // Unless required by applicable law or agreed to in writing, software
    14  // distributed under the License is distributed on an "AS IS" BASIS,
    15  // See the License for the specific language governing permissions and
    16  // limitations under the License.
    17  
    18  package types
    19  
    20  import (
    21  	"math"
    22  	"strconv"
    23  	"strings"
    24  
    25  	"github.com/pingcap/errors"
    26  
    27  	"github.com/XiaoMi/Gaea/mysql"
    28  	"github.com/XiaoMi/Gaea/parser/stmtctx"
    29  	"github.com/XiaoMi/Gaea/parser/terror"
    30  	"github.com/XiaoMi/Gaea/parser/tidb-types/json"
    31  	"github.com/XiaoMi/Gaea/util/hack"
    32  )
    33  
    34  func truncateStr(str string, flen int) string {
    35  	if flen != UnspecifiedLength && len(str) > flen {
    36  		str = str[:flen]
    37  	}
    38  	return str
    39  }
    40  
    41  // UnsignedUpperBound indicates the max uint64 values of different mysql types.
    42  var UnsignedUpperBound = map[byte]uint64{
    43  	mysql.TypeTiny:     math.MaxUint8,
    44  	mysql.TypeShort:    math.MaxUint16,
    45  	mysql.TypeInt24:    mysql.MaxUint24,
    46  	mysql.TypeLong:     math.MaxUint32,
    47  	mysql.TypeLonglong: math.MaxUint64,
    48  	mysql.TypeBit:      math.MaxUint64,
    49  	mysql.TypeEnum:     math.MaxUint64,
    50  	mysql.TypeSet:      math.MaxUint64,
    51  }
    52  
    53  // SignedUpperBound indicates the max int64 values of different mysql types.
    54  var SignedUpperBound = map[byte]int64{
    55  	mysql.TypeTiny:     math.MaxInt8,
    56  	mysql.TypeShort:    math.MaxInt16,
    57  	mysql.TypeInt24:    mysql.MaxInt24,
    58  	mysql.TypeLong:     math.MaxInt32,
    59  	mysql.TypeLonglong: math.MaxInt64,
    60  }
    61  
    62  // SignedLowerBound indicates the min int64 values of different mysql types.
    63  var SignedLowerBound = map[byte]int64{
    64  	mysql.TypeTiny:     math.MinInt8,
    65  	mysql.TypeShort:    math.MinInt16,
    66  	mysql.TypeInt24:    mysql.MinInt24,
    67  	mysql.TypeLong:     math.MinInt32,
    68  	mysql.TypeLonglong: math.MinInt64,
    69  }
    70  
    71  // ConvertFloatToInt converts a float64 value to a int value.
    72  func ConvertFloatToInt(fval float64, lowerBound, upperBound int64, tp byte) (int64, error) {
    73  	val := RoundFloat(fval)
    74  	if val < float64(lowerBound) {
    75  		return lowerBound, overflow(val, tp)
    76  	}
    77  
    78  	if val >= float64(upperBound) {
    79  		if val == float64(upperBound) {
    80  			return upperBound, nil
    81  		}
    82  		return upperBound, overflow(val, tp)
    83  	}
    84  	return int64(val), nil
    85  }
    86  
    87  // ConvertIntToInt converts an int value to another int value of different precision.
    88  func ConvertIntToInt(val int64, lowerBound int64, upperBound int64, tp byte) (int64, error) {
    89  	if val < lowerBound {
    90  		return lowerBound, overflow(val, tp)
    91  	}
    92  
    93  	if val > upperBound {
    94  		return upperBound, overflow(val, tp)
    95  	}
    96  
    97  	return val, nil
    98  }
    99  
   100  // ConvertUintToInt converts an uint value to an int value.
   101  func ConvertUintToInt(val uint64, upperBound int64, tp byte) (int64, error) {
   102  	if val > uint64(upperBound) {
   103  		return upperBound, overflow(val, tp)
   104  	}
   105  
   106  	return int64(val), nil
   107  }
   108  
   109  // ConvertIntToUint converts an int value to an uint value.
   110  func ConvertIntToUint(sc *stmtctx.StatementContext, val int64, upperBound uint64, tp byte) (uint64, error) {
   111  	if sc.ShouldClipToZero() && val < 0 {
   112  		return 0, overflow(val, tp)
   113  	}
   114  
   115  	if uint64(val) > upperBound {
   116  		return upperBound, overflow(val, tp)
   117  	}
   118  
   119  	return uint64(val), nil
   120  }
   121  
   122  // ConvertUintToUint converts an uint value to another uint value of different precision.
   123  func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) {
   124  	if val > upperBound {
   125  		return upperBound, overflow(val, tp)
   126  	}
   127  
   128  	return val, nil
   129  }
   130  
   131  // ConvertFloatToUint converts a float value to an uint value.
   132  func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) {
   133  	val := RoundFloat(fval)
   134  	if val < 0 {
   135  		if sc.ShouldClipToZero() {
   136  			return 0, overflow(val, tp)
   137  		}
   138  		return uint64(int64(val)), overflow(val, tp)
   139  	}
   140  
   141  	if val > float64(upperBound) {
   142  		return upperBound, overflow(val, tp)
   143  	}
   144  	return uint64(val), nil
   145  }
   146  
   147  // StrToInt converts a string to an integer at the best-effort.
   148  func StrToInt(sc *stmtctx.StatementContext, str string) (int64, error) {
   149  	str = strings.TrimSpace(str)
   150  	validPrefix, err := getValidIntPrefix(sc, str)
   151  	iVal, err1 := strconv.ParseInt(validPrefix, 10, 64)
   152  	if err1 != nil {
   153  		return iVal, ErrOverflow.GenWithStackByArgs("BIGINT", validPrefix)
   154  	}
   155  	return iVal, errors.Trace(err)
   156  }
   157  
   158  // StrToUint converts a string to an unsigned integer at the best-effortt.
   159  func StrToUint(sc *stmtctx.StatementContext, str string) (uint64, error) {
   160  	str = strings.TrimSpace(str)
   161  	validPrefix, err := getValidIntPrefix(sc, str)
   162  	if validPrefix[0] == '+' {
   163  		validPrefix = validPrefix[1:]
   164  	}
   165  	uVal, err1 := strconv.ParseUint(validPrefix, 10, 64)
   166  	if err1 != nil {
   167  		return uVal, ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", validPrefix)
   168  	}
   169  	return uVal, errors.Trace(err)
   170  }
   171  
   172  // StrToDateTime converts str to MySQL DateTime.
   173  func StrToDateTime(sc *stmtctx.StatementContext, str string, fsp int) (Time, error) {
   174  	return ParseTime(sc, str, mysql.TypeDatetime, fsp)
   175  }
   176  
   177  // StrToDuration converts str to Duration. It returns Duration in normal case,
   178  // and returns Time when str is in datetime format.
   179  // when isDuration is true, the d is returned, when it is false, the t is returned.
   180  // See https://dev.mysql.com/doc/refman/5.5/en/date-and-time-literals.html.
   181  func StrToDuration(sc *stmtctx.StatementContext, str string, fsp int) (d Duration, t Time, isDuration bool, err error) {
   182  	str = strings.TrimSpace(str)
   183  	length := len(str)
   184  	if length > 0 && str[0] == '-' {
   185  		length--
   186  	}
   187  	// Timestamp format is 'YYYYMMDDHHMMSS' or 'YYMMDDHHMMSS', which length is 12.
   188  	// See #3923, it explains what we do here.
   189  	if length >= 12 {
   190  		t, err = StrToDateTime(sc, str, fsp)
   191  		if err == nil {
   192  			return d, t, false, nil
   193  		}
   194  	}
   195  
   196  	d, err = ParseDuration(sc, str, fsp)
   197  	if ErrTruncatedWrongVal.Equal(err) {
   198  		err = sc.HandleTruncate(err)
   199  	}
   200  	return d, t, true, errors.Trace(err)
   201  }
   202  
   203  // NumberToDuration converts number to Duration.
   204  func NumberToDuration(number int64, fsp int) (Duration, error) {
   205  	if number > TimeMaxValue {
   206  		// Try to parse DATETIME.
   207  		if number >= 10000000000 { // '2001-00-00 00-00-00'
   208  			if t, err := ParseDatetimeFromNum(nil, number); err == nil {
   209  				dur, err1 := t.ConvertToDuration()
   210  				return dur, errors.Trace(err1)
   211  			}
   212  		}
   213  		dur, err1 := MaxMySQLTime(fsp).ConvertToDuration()
   214  		terror.Log(err1)
   215  		return dur, ErrOverflow.GenWithStackByArgs("Duration", strconv.Itoa(int(number)))
   216  	} else if number < -TimeMaxValue {
   217  		dur, err1 := MaxMySQLTime(fsp).ConvertToDuration()
   218  		terror.Log(err1)
   219  		dur.Duration = -dur.Duration
   220  		return dur, ErrOverflow.GenWithStackByArgs("Duration", strconv.Itoa(int(number)))
   221  	}
   222  	var neg bool
   223  	if neg = number < 0; neg {
   224  		number = -number
   225  	}
   226  
   227  	if number/10000 > TimeMaxHour || number%100 >= 60 || (number/100)%100 >= 60 {
   228  		return ZeroDuration, errors.Trace(ErrInvalidTimeFormat.GenWithStackByArgs(number))
   229  	}
   230  	t := Time{Time: FromDate(0, 0, 0, int(number/10000), int((number/100)%100), int(number%100), 0), Type: mysql.TypeDuration, Fsp: fsp}
   231  	dur, err := t.ConvertToDuration()
   232  	if err != nil {
   233  		return ZeroDuration, errors.Trace(err)
   234  	}
   235  	if neg {
   236  		dur.Duration = -dur.Duration
   237  	}
   238  	return dur, nil
   239  }
   240  
   241  // getValidIntPrefix gets prefix of the string which can be successfully parsed as int.
   242  func getValidIntPrefix(sc *stmtctx.StatementContext, str string) (string, error) {
   243  	floatPrefix, err := getValidFloatPrefix(sc, str)
   244  	if err != nil {
   245  		return floatPrefix, errors.Trace(err)
   246  	}
   247  	return floatStrToIntStr(sc, floatPrefix, str)
   248  }
   249  
   250  // roundIntStr is to round int string base on the number following dot.
   251  func roundIntStr(numNextDot byte, intStr string) string {
   252  	if numNextDot < '5' {
   253  		return intStr
   254  	}
   255  	retStr := []byte(intStr)
   256  	for i := len(intStr) - 1; i >= 0; i-- {
   257  		if retStr[i] != '9' {
   258  			retStr[i]++
   259  			break
   260  		}
   261  		if i == 0 {
   262  			retStr[i] = '1'
   263  			retStr = append(retStr, '0')
   264  			break
   265  		}
   266  		retStr[i] = '0'
   267  	}
   268  	return string(retStr)
   269  }
   270  
   271  // floatStrToIntStr converts a valid float string into valid integer string which can be parsed by
   272  // strconv.ParseInt, we can't parse float first then convert it to string because precision will
   273  // be lost. For example, the string value "18446744073709551615" which is the max number of unsigned
   274  // int will cause some precision to lose. intStr[0] may be a positive and negative sign like '+' or '-'.
   275  func floatStrToIntStr(sc *stmtctx.StatementContext, validFloat string, oriStr string) (intStr string, _ error) {
   276  	var dotIdx = -1
   277  	var eIdx = -1
   278  	for i := 0; i < len(validFloat); i++ {
   279  		switch validFloat[i] {
   280  		case '.':
   281  			dotIdx = i
   282  		case 'e', 'E':
   283  			eIdx = i
   284  		}
   285  	}
   286  	if eIdx == -1 {
   287  		if dotIdx == -1 {
   288  			return validFloat, nil
   289  		}
   290  		var digits []byte
   291  		if validFloat[0] == '-' || validFloat[0] == '+' {
   292  			dotIdx--
   293  			digits = []byte(validFloat[1:])
   294  		} else {
   295  			digits = []byte(validFloat)
   296  		}
   297  		if dotIdx == 0 {
   298  			intStr = "0"
   299  		} else {
   300  			intStr = string(digits)[:dotIdx]
   301  		}
   302  		if len(digits) > dotIdx+1 {
   303  			intStr = roundIntStr(digits[dotIdx+1], intStr)
   304  		}
   305  		if (len(intStr) > 1 || intStr[0] != '0') && validFloat[0] == '-' {
   306  			intStr = "-" + intStr
   307  		}
   308  		return intStr, nil
   309  	}
   310  	var intCnt int
   311  	digits := make([]byte, 0, len(validFloat))
   312  	if dotIdx == -1 {
   313  		digits = append(digits, validFloat[:eIdx]...)
   314  		intCnt = len(digits)
   315  	} else {
   316  		digits = append(digits, validFloat[:dotIdx]...)
   317  		intCnt = len(digits)
   318  		digits = append(digits, validFloat[dotIdx+1:eIdx]...)
   319  	}
   320  	exp, err := strconv.Atoi(validFloat[eIdx+1:])
   321  	if err != nil {
   322  		return validFloat, errors.Trace(err)
   323  	}
   324  	if exp > 0 && int64(intCnt) > (math.MaxInt64-int64(exp)) {
   325  		// (exp + incCnt) overflows MaxInt64.
   326  		sc.AppendWarning(ErrOverflow.GenWithStackByArgs("BIGINT", oriStr))
   327  		return validFloat[:eIdx], nil
   328  	}
   329  	intCnt += exp
   330  	if intCnt <= 0 {
   331  		intStr = "0"
   332  		if intCnt == 0 && len(digits) > 0 {
   333  			intStr = roundIntStr(digits[0], intStr)
   334  		}
   335  		return intStr, nil
   336  	}
   337  	if intCnt == 1 && (digits[0] == '-' || digits[0] == '+') {
   338  		intStr = "0"
   339  		if len(digits) > 1 {
   340  			intStr = roundIntStr(digits[1], intStr)
   341  		}
   342  		if intStr[0] == '1' {
   343  			intStr = string(digits[:1]) + intStr
   344  		}
   345  		return intStr, nil
   346  	}
   347  	if intCnt <= len(digits) {
   348  		intStr = string(digits[:intCnt])
   349  		if intCnt < len(digits) {
   350  			intStr = roundIntStr(digits[intCnt], intStr)
   351  		}
   352  	} else {
   353  		// convert scientific notation decimal number
   354  		extraZeroCount := intCnt - len(digits)
   355  		if extraZeroCount > 20 {
   356  			// Append overflow warning and return to avoid allocating too much memory.
   357  			sc.AppendWarning(ErrOverflow.GenWithStackByArgs("BIGINT", oriStr))
   358  			return validFloat[:eIdx], nil
   359  		}
   360  		intStr = string(digits) + strings.Repeat("0", extraZeroCount)
   361  	}
   362  	return intStr, nil
   363  }
   364  
   365  // StrToFloat converts a string to a float64 at the best-effort.
   366  func StrToFloat(sc *stmtctx.StatementContext, str string) (float64, error) {
   367  	str = strings.TrimSpace(str)
   368  	validStr, err := getValidFloatPrefix(sc, str)
   369  	f, err1 := strconv.ParseFloat(validStr, 64)
   370  	if err1 != nil {
   371  		if err2, ok := err1.(*strconv.NumError); ok {
   372  			// value will truncate to MAX/MIN if out of range.
   373  			if err2.Err == strconv.ErrRange {
   374  				err1 = sc.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("DOUBLE", str))
   375  				if math.IsInf(f, 1) {
   376  					f = math.MaxFloat64
   377  				} else if math.IsInf(f, -1) {
   378  					f = -math.MaxFloat64
   379  				}
   380  			}
   381  		}
   382  		return f, errors.Trace(err1)
   383  	}
   384  	return f, errors.Trace(err)
   385  }
   386  
   387  // ConvertJSONToInt casts JSON into int64.
   388  func ConvertJSONToInt(sc *stmtctx.StatementContext, j json.BinaryJSON, unsigned bool) (int64, error) {
   389  	switch j.TypeCode {
   390  	case json.TypeCodeObject, json.TypeCodeArray:
   391  		return 0, nil
   392  	case json.TypeCodeLiteral:
   393  		switch j.Value[0] {
   394  		case json.LiteralNil, json.LiteralFalse:
   395  			return 0, nil
   396  		default:
   397  			return 1, nil
   398  		}
   399  	case json.TypeCodeInt64, json.TypeCodeUint64:
   400  		return j.GetInt64(), nil
   401  	case json.TypeCodeFloat64:
   402  		f := j.GetFloat64()
   403  		if !unsigned {
   404  			lBound := SignedLowerBound[mysql.TypeLonglong]
   405  			uBound := SignedUpperBound[mysql.TypeLonglong]
   406  			return ConvertFloatToInt(f, lBound, uBound, mysql.TypeDouble)
   407  		}
   408  		bound := UnsignedUpperBound[mysql.TypeLonglong]
   409  		u, err := ConvertFloatToUint(sc, f, bound, mysql.TypeDouble)
   410  		return int64(u), errors.Trace(err)
   411  	case json.TypeCodeString:
   412  		str := string(hack.String(j.GetString()))
   413  		return StrToInt(sc, str)
   414  	}
   415  	return 0, errors.New("Unknown type code in JSON")
   416  }
   417  
   418  // ConvertJSONToFloat casts JSON into float64.
   419  func ConvertJSONToFloat(sc *stmtctx.StatementContext, j json.BinaryJSON) (float64, error) {
   420  	switch j.TypeCode {
   421  	case json.TypeCodeObject, json.TypeCodeArray:
   422  		return 0, nil
   423  	case json.TypeCodeLiteral:
   424  		switch j.Value[0] {
   425  		case json.LiteralNil, json.LiteralFalse:
   426  			return 0, nil
   427  		default:
   428  			return 1, nil
   429  		}
   430  	case json.TypeCodeInt64:
   431  		return float64(j.GetInt64()), nil
   432  	case json.TypeCodeUint64:
   433  		u, err := ConvertIntToUint(sc, j.GetInt64(), UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
   434  		return float64(u), errors.Trace(err)
   435  	case json.TypeCodeFloat64:
   436  		return j.GetFloat64(), nil
   437  	case json.TypeCodeString:
   438  		str := string(hack.String(j.GetString()))
   439  		return StrToFloat(sc, str)
   440  	}
   441  	return 0, errors.New("Unknown type code in JSON")
   442  }
   443  
   444  // ConvertJSONToDecimal casts JSON into decimal.
   445  func ConvertJSONToDecimal(sc *stmtctx.StatementContext, j json.BinaryJSON) (*MyDecimal, error) {
   446  	res := new(MyDecimal)
   447  	if j.TypeCode != json.TypeCodeString {
   448  		f64, err := ConvertJSONToFloat(sc, j)
   449  		if err != nil {
   450  			return res, errors.Trace(err)
   451  		}
   452  		err = res.FromFloat64(f64)
   453  		return res, errors.Trace(err)
   454  	}
   455  	err := sc.HandleTruncate(res.FromString([]byte(j.GetString())))
   456  	return res, errors.Trace(err)
   457  }
   458  
   459  // getValidFloatPrefix gets prefix of string which can be successfully parsed as float.
   460  func getValidFloatPrefix(sc *stmtctx.StatementContext, s string) (valid string, err error) {
   461  	var (
   462  		sawDot   bool
   463  		sawDigit bool
   464  		validLen int
   465  		eIdx     int
   466  	)
   467  	for i := 0; i < len(s); i++ {
   468  		c := s[i]
   469  		if c == '+' || c == '-' {
   470  			if i != 0 && i != eIdx+1 { // "1e+1" is valid.
   471  				break
   472  			}
   473  		} else if c == '.' {
   474  			if sawDot || eIdx > 0 { // "1.1." or "1e1.1"
   475  				break
   476  			}
   477  			sawDot = true
   478  			if sawDigit { // "123." is valid.
   479  				validLen = i + 1
   480  			}
   481  		} else if c == 'e' || c == 'E' {
   482  			if !sawDigit { // "+.e"
   483  				break
   484  			}
   485  			if eIdx != 0 { // "1e5e"
   486  				break
   487  			}
   488  			eIdx = i
   489  		} else if c < '0' || c > '9' {
   490  			break
   491  		} else {
   492  			sawDigit = true
   493  			validLen = i + 1
   494  		}
   495  	}
   496  	valid = s[:validLen]
   497  	if valid == "" {
   498  		valid = "0"
   499  	}
   500  	if validLen == 0 || validLen != len(s) {
   501  		err = errors.Trace(handleTruncateError(sc))
   502  	}
   503  	return valid, err
   504  }
   505  
   506  // ToString converts an interface to a string.
   507  func ToString(value interface{}) (string, error) {
   508  	switch v := value.(type) {
   509  	case bool:
   510  		if v {
   511  			return "1", nil
   512  		}
   513  		return "0", nil
   514  	case int:
   515  		return strconv.FormatInt(int64(v), 10), nil
   516  	case int64:
   517  		return strconv.FormatInt(v, 10), nil
   518  	case uint64:
   519  		return strconv.FormatUint(v, 10), nil
   520  	case float32:
   521  		return strconv.FormatFloat(float64(v), 'f', -1, 32), nil
   522  	case float64:
   523  		return strconv.FormatFloat(v, 'f', -1, 64), nil
   524  	case string:
   525  		return v, nil
   526  	case []byte:
   527  		return string(v), nil
   528  	case Time:
   529  		return v.String(), nil
   530  	case Duration:
   531  		return v.String(), nil
   532  	case *MyDecimal:
   533  		return v.String(), nil
   534  	case BinaryLiteral:
   535  		return v.ToString(), nil
   536  	case Enum:
   537  		return v.String(), nil
   538  	case Set:
   539  		return v.String(), nil
   540  	default:
   541  		return "", errors.Errorf("cannot convert %v(type %T) to string", value, value)
   542  	}
   543  }