github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/convert.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expression 16 17 import ( 18 "encoding/hex" 19 "fmt" 20 "strconv" 21 "strings" 22 "time" 23 24 "github.com/sirupsen/logrus" 25 "gopkg.in/src-d/go-errors.v1" 26 27 "github.com/dolthub/go-mysql-server/sql" 28 "github.com/dolthub/go-mysql-server/sql/types" 29 ) 30 31 // ErrConvertExpression is returned when a conversion is not possible. 32 var ErrConvertExpression = errors.NewKind("expression '%v': couldn't convert to %v") 33 34 const ( 35 // ConvertToBinary is a conversion to binary. 36 ConvertToBinary = "binary" 37 // ConvertToChar is a conversion to char. 38 ConvertToChar = "char" 39 // ConvertToNChar is a conversion to nchar. 40 ConvertToNChar = "nchar" 41 // ConvertToDate is a conversion to date. 42 ConvertToDate = "date" 43 // ConvertToDatetime is a conversion to datetime. 44 ConvertToDatetime = "datetime" 45 // ConvertToDecimal is a conversion to decimal. 46 ConvertToDecimal = "decimal" 47 // ConvertToFloat is a conversion to float. 48 ConvertToFloat = "float" 49 // ConvertToDouble is a conversion to double. 50 ConvertToDouble = "double" 51 // ConvertToJSON is a conversion to json. 52 ConvertToJSON = "json" 53 // ConvertToReal is a conversion to double. 54 ConvertToReal = "real" 55 // ConvertToSigned is a conversion to signed. 56 ConvertToSigned = "signed" 57 // ConvertToTime is a conversion to time. 58 ConvertToTime = "time" 59 // ConvertToUnsigned is a conversion to unsigned. 60 ConvertToUnsigned = "unsigned" 61 ) 62 63 // Convert represent a CAST(x AS T) or CONVERT(x, T) operation that casts x expression to type T. 64 type Convert struct { 65 UnaryExpression 66 // castToType is a string representation of the base type to which we are casting (e.g. "char", "float", "decimal") 67 castToType string 68 // typeLength is the optional length parameter for types that support it (e.g. "char(10)") 69 typeLength int 70 // typeScale is the optional scale parameter for types that support it (e.g. "decimal(10, 2)") 71 typeScale int 72 // cachedDecimalType is the cached Decimal type for this convert expression. Because new Decimal types 73 // must be created with their specific scale and precision values, unlike other types, we cache the created 74 // type to avoid re-creating it on every call to Type(). 75 cachedDecimalType sql.DecimalType 76 } 77 78 var _ sql.Expression = (*Convert)(nil) 79 var _ sql.CollationCoercible = (*Convert)(nil) 80 81 // NewConvert creates a new Convert expression that will attempt to convert the specified expression |expr| into the 82 // |castToType| type. All optional parameters (i.e. typeLength, typeScale, and charset) are omitted and initialized 83 // to their zero values. 84 func NewConvert(expr sql.Expression, castToType string) *Convert { 85 disableRounding(expr) 86 return &Convert{ 87 UnaryExpression: UnaryExpression{Child: expr}, 88 castToType: strings.ToLower(castToType), 89 } 90 } 91 92 // NewConvertWithLengthAndScale creates a new Convert expression that will attempt to convert |expr| into the 93 // |castToType| type, with |typeLength| specifying a length constraint of the converted type, and |typeScale| specifying 94 // a scale constraint of the converted type. 95 func NewConvertWithLengthAndScale(expr sql.Expression, castToType string, typeLength, typeScale int) *Convert { 96 disableRounding(expr) 97 return &Convert{ 98 UnaryExpression: UnaryExpression{Child: expr}, 99 castToType: strings.ToLower(castToType), 100 typeLength: typeLength, 101 typeScale: typeScale, 102 } 103 } 104 105 // GetConvertToType returns which type the both left and right values should be converted to. 106 // If neither sql.Type represent number, then converted to string. Otherwise, we try to get 107 // the appropriate type to avoid any precision loss. 108 func GetConvertToType(l, r sql.Type) string { 109 if types.Null == l { 110 return GetConvertToType(r, r) 111 } 112 if types.Null == r { 113 return GetConvertToType(l, l) 114 } 115 116 if !types.IsNumber(l) || !types.IsNumber(r) { 117 return ConvertToChar 118 } 119 120 if types.IsDecimal(l) || types.IsDecimal(r) { 121 return ConvertToDecimal 122 } 123 if types.IsUnsigned(l) && types.IsUnsigned(r) { 124 return ConvertToUnsigned 125 } 126 if types.IsSigned(l) && types.IsSigned(r) { 127 return ConvertToSigned 128 } 129 if types.IsInteger(l) && types.IsInteger(r) { 130 return ConvertToSigned 131 } 132 133 return ConvertToChar 134 } 135 136 // IsNullable implements the Expression interface. 137 func (c *Convert) IsNullable() bool { 138 switch c.castToType { 139 case ConvertToDate, ConvertToDatetime: 140 return true 141 default: 142 return c.Child.IsNullable() 143 } 144 } 145 146 // Type implements the Expression interface. 147 func (c *Convert) Type() sql.Type { 148 switch c.castToType { 149 case ConvertToBinary: 150 return types.LongBlob 151 case ConvertToChar, ConvertToNChar: 152 return types.LongText 153 case ConvertToDate: 154 return types.Date 155 case ConvertToDatetime: 156 return types.DatetimeMaxPrecision 157 case ConvertToDecimal: 158 if c.cachedDecimalType == nil { 159 c.cachedDecimalType = createConvertedDecimalType(c.typeLength, c.typeScale, true) 160 } 161 return c.cachedDecimalType 162 case ConvertToFloat: 163 return types.Float32 164 case ConvertToDouble, ConvertToReal: 165 return types.Float64 166 case ConvertToJSON: 167 return types.JSON 168 case ConvertToSigned: 169 return types.Int64 170 case ConvertToTime: 171 return types.Time 172 case ConvertToUnsigned: 173 return types.Uint64 174 default: 175 return types.Null 176 } 177 } 178 179 // CollationCoercibility implements the interface sql.CollationCoercible. 180 func (c *Convert) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 181 switch c.castToType { 182 case ConvertToBinary: 183 return sql.Collation_binary, 2 184 case ConvertToChar, ConvertToNChar: 185 return ctx.GetCollation(), 2 186 case ConvertToDate: 187 return sql.Collation_binary, 5 188 case ConvertToDatetime: 189 return sql.Collation_binary, 5 190 case ConvertToDecimal: 191 return sql.Collation_binary, 5 192 case ConvertToDouble, ConvertToReal, ConvertToFloat: 193 return sql.Collation_binary, 5 194 case ConvertToJSON: 195 return ctx.GetCharacterSet().BinaryCollation(), 2 196 case ConvertToSigned: 197 return sql.Collation_binary, 5 198 case ConvertToTime: 199 return sql.Collation_binary, 5 200 case ConvertToUnsigned: 201 return sql.Collation_binary, 5 202 default: 203 return sql.Collation_binary, 7 204 } 205 } 206 207 // String implements the Stringer interface. 208 func (c *Convert) String() string { 209 extraTypeInfo := "" 210 if c.typeLength > 0 { 211 if c.typeScale > 0 { 212 extraTypeInfo = fmt.Sprintf("(%d,%d)", c.typeLength, c.typeScale) 213 } else { 214 extraTypeInfo = fmt.Sprintf("(%d)", c.typeLength) 215 } 216 } 217 return fmt.Sprintf("convert(%v, %v%s)", c.Child, c.castToType, extraTypeInfo) 218 } 219 220 // DebugString implements the Expression interface. 221 func (c *Convert) DebugString() string { 222 pr := sql.NewTreePrinter() 223 _ = pr.WriteNode("convert") 224 children := []string{ 225 fmt.Sprintf("type: %v", c.castToType), 226 } 227 228 if c.typeLength > 0 { 229 children = append(children, fmt.Sprintf("typeLength: %v", c.typeLength)) 230 } 231 232 if c.typeScale > 0 { 233 children = append(children, fmt.Sprintf("typeScale: %v", c.typeScale)) 234 } 235 236 children = append(children, fmt.Sprintf(sql.DebugString(c.Child))) 237 238 _ = pr.WriteChildren(children...) 239 return pr.String() 240 } 241 242 // WithChildren implements the Expression interface. 243 func (c *Convert) WithChildren(children ...sql.Expression) (sql.Expression, error) { 244 if len(children) != 1 { 245 return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) 246 } 247 return NewConvertWithLengthAndScale(children[0], c.castToType, c.typeLength, c.typeScale), nil 248 } 249 250 // Eval implements the Expression interface. 251 func (c *Convert) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 252 val, err := c.Child.Eval(ctx, row) 253 if err != nil { 254 return nil, err 255 } 256 257 if val == nil { 258 return nil, nil 259 } 260 261 // Should always return nil, and a warning instead 262 casted, err := convertValue(val, c.castToType, c.Child.Type(), c.typeLength, c.typeScale) 263 if err != nil { 264 if c.castToType == ConvertToJSON { 265 return nil, ErrConvertExpression.Wrap(err, c.String(), c.castToType) 266 } 267 ctx.Warn(1292, "Incorrect %s value: %v", c.castToType, val) 268 return nil, nil 269 } 270 271 return casted, nil 272 } 273 274 // convertValue only returns an error if converting to JSON, Date, and Datetime; 275 // the zero value is returned for float types. Nil is returned in all other cases. 276 // If |typeLength| and |typeScale| are 0, they are ignored, otherwise they are used as constraints on the 277 // converted type where applicable (e.g. Char conversion supports only |typeLength|, Decimal conversion supports 278 // |typeLength| and |typeScale|). 279 func convertValue(val interface{}, castTo string, originType sql.Type, typeLength, typeScale int) (interface{}, error) { 280 switch strings.ToLower(castTo) { 281 case ConvertToBinary: 282 b, _, err := types.LongBlob.Convert(val) 283 if err != nil { 284 return nil, nil 285 } 286 if types.IsTextOnly(originType) { 287 // For string types we need to re-encode the string as we want the binary representation of the character set 288 encoder := originType.(sql.StringType).Collation().CharacterSet().Encoder() 289 encodedBytes, ok := encoder.Encode(b.([]byte)) 290 if !ok { 291 return nil, fmt.Errorf("unable to re-encode string to convert to binary") 292 } 293 b = encodedBytes 294 } 295 return truncateConvertedValue(b, typeLength) 296 case ConvertToChar, ConvertToNChar: 297 s, _, err := types.LongText.Convert(val) 298 if err != nil { 299 return nil, nil 300 } 301 return truncateConvertedValue(s, typeLength) 302 case ConvertToDate: 303 _, isTime := val.(time.Time) 304 _, isString := val.(string) 305 _, isBinary := val.([]byte) 306 if !(isTime || isString || isBinary) { 307 return nil, nil 308 } 309 d, _, err := types.Date.Convert(val) 310 if err != nil { 311 return nil, err 312 } 313 return d, nil 314 case ConvertToDatetime: 315 _, isTime := val.(time.Time) 316 _, isString := val.(string) 317 _, isBinary := val.([]byte) 318 if !(isTime || isString || isBinary) { 319 return nil, nil 320 } 321 d, _, err := types.DatetimeMaxPrecision.Convert(val) 322 if err != nil { 323 return nil, err 324 } 325 return d, nil 326 case ConvertToDecimal: 327 value, err := convertHexBlobToDecimalForNumericContext(val, originType) 328 if err != nil { 329 return nil, err 330 } 331 dt := createConvertedDecimalType(typeLength, typeScale, false) 332 d, _, err := dt.Convert(value) 333 if err != nil { 334 return "0", nil 335 } 336 return d, nil 337 case ConvertToFloat: 338 value, err := convertHexBlobToDecimalForNumericContext(val, originType) 339 if err != nil { 340 return nil, err 341 } 342 d, _, err := types.Float32.Convert(value) 343 if err != nil { 344 return types.Float32.Zero(), nil 345 } 346 return d, nil 347 case ConvertToDouble, ConvertToReal: 348 value, err := convertHexBlobToDecimalForNumericContext(val, originType) 349 if err != nil { 350 return nil, err 351 } 352 d, _, err := types.Float64.Convert(value) 353 if err != nil { 354 return types.Float64.Zero(), nil 355 } 356 return d, nil 357 case ConvertToJSON: 358 js, _, err := types.JSON.Convert(val) 359 if err != nil { 360 return nil, err 361 } 362 return js, nil 363 case ConvertToSigned: 364 value, err := convertHexBlobToDecimalForNumericContext(val, originType) 365 if err != nil { 366 return nil, err 367 } 368 num, _, err := types.Int64.Convert(value) 369 if err != nil { 370 return types.Int64.Zero(), nil 371 } 372 373 return num, nil 374 case ConvertToTime: 375 t, _, err := types.Time.Convert(val) 376 if err != nil { 377 return nil, nil 378 } 379 return t, nil 380 case ConvertToUnsigned: 381 value, err := convertHexBlobToDecimalForNumericContext(val, originType) 382 if err != nil { 383 return nil, err 384 } 385 num, _, err := types.Uint64.Convert(value) 386 if err != nil { 387 num, _, err = types.Int64.Convert(value) 388 if err != nil { 389 return types.Uint64.Zero(), nil 390 } 391 return uint64(num.(int64)), nil 392 } 393 return num, nil 394 default: 395 return nil, nil 396 } 397 } 398 399 // truncateConvertedValue truncates |val| to the specified |typeLength| if |val| 400 // is a string or byte slice. If the typeLength is 0, or if it is greater than 401 // the length of |val|, then |val| is simply returned as is. If |val| is not a 402 // string or []byte, then an error is returned. 403 func truncateConvertedValue(val interface{}, typeLength int) (interface{}, error) { 404 if typeLength <= 0 { 405 return val, nil 406 } 407 408 switch v := val.(type) { 409 case []byte: 410 if len(v) <= typeLength { 411 typeLength = len(v) 412 } 413 return v[:typeLength], nil 414 case string: 415 if len(v) <= typeLength { 416 typeLength = len(v) 417 } 418 return v[:typeLength], nil 419 default: 420 return nil, fmt.Errorf("unsupported type for truncation: %T", val) 421 } 422 } 423 424 // createConvertedDecimalType creates a new Decimal type with the specified |precision| and |scale|. If a Decimal 425 // type cannot be created from the values specified, the internal Decimal type is returned. If |logErrors| is true, 426 // an error will also logged to the standard logger. (Setting |logErrors| to false, allows the caller to prevent 427 // spurious error message from being logged multiple times for the same error.) This function is intended to be 428 // used in places where an error cannot be returned (e.g. Node.Type() implementations), hence why it logs an error 429 // instead of returning one. 430 func createConvertedDecimalType(length, scale int, logErrors bool) sql.DecimalType { 431 if length > 0 && scale > 0 { 432 dt, err := types.CreateColumnDecimalType(uint8(length), uint8(scale)) 433 if err != nil { 434 if logErrors { 435 logrus.StandardLogger().Errorf("unable to create decimal type with length %d and scale %d: %v", length, scale, err) 436 } 437 return types.InternalDecimalType 438 } 439 return dt 440 } 441 return types.InternalDecimalType 442 } 443 444 // convertHexBlobToDecimalForNumericContext converts byte array value to unsigned int value if originType is BLOB type. 445 // This function is called when convertTo type is number type only. The hex literal values are parsed into blobs as 446 // binary string as default, but for numeric context, the value should be a number. 447 // Byte arrays of other SQL types are not handled here. 448 func convertHexBlobToDecimalForNumericContext(val interface{}, originType sql.Type) (interface{}, error) { 449 if bin, isBinary := val.([]byte); isBinary && types.IsBlobType(originType) { 450 stringVal := hex.EncodeToString(bin) 451 decimalNum, err := strconv.ParseUint(stringVal, 16, 64) 452 if err != nil { 453 return nil, errors.NewKind("failed to convert hex blob value to unsigned int").New() 454 } 455 val = decimalNum 456 } 457 return val, nil 458 }