github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/convert.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"encoding/hex"
    19  	"fmt"
    20  	"strconv"
    21  	"strings"
    22  	"time"
    23  
    24  	"github.com/sirupsen/logrus"
    25  	"gopkg.in/src-d/go-errors.v1"
    26  
    27  	"github.com/dolthub/go-mysql-server/sql"
    28  	"github.com/dolthub/go-mysql-server/sql/types"
    29  )
    30  
// ErrConvertExpression is returned when a conversion is not possible. Its two format arguments are the
// expression being converted and the name of the type it could not be converted to.
var ErrConvertExpression = errors.NewKind("expression '%v': couldn't convert to %v")
    33  
// Names of the target types understood by Convert. They are all lower case, because Convert lower-cases the
// requested type name before comparing it against these values.
const (
	// ConvertToBinary is a conversion to binary.
	ConvertToBinary = "binary"
	// ConvertToChar is a conversion to char.
	ConvertToChar = "char"
	// ConvertToNChar is a conversion to nchar. Convert handles it exactly like ConvertToChar.
	ConvertToNChar = "nchar"
	// ConvertToDate is a conversion to date.
	ConvertToDate = "date"
	// ConvertToDatetime is a conversion to datetime.
	ConvertToDatetime = "datetime"
	// ConvertToDecimal is a conversion to decimal.
	ConvertToDecimal = "decimal"
	// ConvertToFloat is a conversion to float.
	ConvertToFloat = "float"
	// ConvertToDouble is a conversion to double.
	ConvertToDouble = "double"
	// ConvertToJSON is a conversion to json.
	ConvertToJSON = "json"
	// ConvertToReal is a conversion to real. Convert handles it exactly like ConvertToDouble.
	ConvertToReal = "real"
	// ConvertToSigned is a conversion to signed.
	ConvertToSigned = "signed"
	// ConvertToTime is a conversion to time.
	ConvertToTime = "time"
	// ConvertToUnsigned is a conversion to unsigned.
	ConvertToUnsigned = "unsigned"
)
    62  
// Convert represents a CAST(x AS T) or CONVERT(x, T) operation that casts the expression x to the type T.
type Convert struct {
	UnaryExpression
	// castToType is a string representation of the base type to which we are casting (e.g. "char", "float",
	// "decimal"). The constructors store it lower-cased.
	castToType string
	// typeLength is the optional length parameter for types that support it (e.g. "char(10)"). Zero means
	// that no length was specified.
	typeLength int
	// typeScale is the optional scale parameter for types that support it (e.g. "decimal(10, 2)"). Zero means
	// that no scale was specified.
	typeScale int
	// cachedDecimalType is the cached Decimal type for this convert expression. Because new Decimal types
	// must be created with their specific scale and precision values, unlike other types, we cache the created
	// type to avoid re-creating it on every call to Type(). It is populated lazily by Type().
	cachedDecimalType sql.DecimalType
}

// Compile-time checks that *Convert satisfies the interfaces it is expected to implement.
var _ sql.Expression = (*Convert)(nil)
var _ sql.CollationCoercible = (*Convert)(nil)
    80  
    81  // NewConvert creates a new Convert expression that will attempt to convert the specified expression |expr| into the
    82  // |castToType| type. All optional parameters (i.e. typeLength, typeScale, and charset) are omitted and initialized
    83  // to their zero values.
    84  func NewConvert(expr sql.Expression, castToType string) *Convert {
    85  	disableRounding(expr)
    86  	return &Convert{
    87  		UnaryExpression: UnaryExpression{Child: expr},
    88  		castToType:      strings.ToLower(castToType),
    89  	}
    90  }
    91  
    92  // NewConvertWithLengthAndScale creates a new Convert expression that will attempt to convert |expr| into the
    93  // |castToType| type, with |typeLength| specifying a length constraint of the converted type, and |typeScale| specifying
    94  // a scale constraint of the converted type.
    95  func NewConvertWithLengthAndScale(expr sql.Expression, castToType string, typeLength, typeScale int) *Convert {
    96  	disableRounding(expr)
    97  	return &Convert{
    98  		UnaryExpression: UnaryExpression{Child: expr},
    99  		castToType:      strings.ToLower(castToType),
   100  		typeLength:      typeLength,
   101  		typeScale:       typeScale,
   102  	}
   103  }
   104  
   105  // GetConvertToType returns which type the both left and right values should be converted to.
   106  // If neither sql.Type represent number, then converted to string. Otherwise, we try to get
   107  // the appropriate type to avoid any precision loss.
   108  func GetConvertToType(l, r sql.Type) string {
   109  	if types.Null == l {
   110  		return GetConvertToType(r, r)
   111  	}
   112  	if types.Null == r {
   113  		return GetConvertToType(l, l)
   114  	}
   115  
   116  	if !types.IsNumber(l) || !types.IsNumber(r) {
   117  		return ConvertToChar
   118  	}
   119  
   120  	if types.IsDecimal(l) || types.IsDecimal(r) {
   121  		return ConvertToDecimal
   122  	}
   123  	if types.IsUnsigned(l) && types.IsUnsigned(r) {
   124  		return ConvertToUnsigned
   125  	}
   126  	if types.IsSigned(l) && types.IsSigned(r) {
   127  		return ConvertToSigned
   128  	}
   129  	if types.IsInteger(l) && types.IsInteger(r) {
   130  		return ConvertToSigned
   131  	}
   132  
   133  	return ConvertToChar
   134  }
   135  
   136  // IsNullable implements the Expression interface.
   137  func (c *Convert) IsNullable() bool {
   138  	switch c.castToType {
   139  	case ConvertToDate, ConvertToDatetime:
   140  		return true
   141  	default:
   142  		return c.Child.IsNullable()
   143  	}
   144  }
   145  
   146  // Type implements the Expression interface.
   147  func (c *Convert) Type() sql.Type {
   148  	switch c.castToType {
   149  	case ConvertToBinary:
   150  		return types.LongBlob
   151  	case ConvertToChar, ConvertToNChar:
   152  		return types.LongText
   153  	case ConvertToDate:
   154  		return types.Date
   155  	case ConvertToDatetime:
   156  		return types.DatetimeMaxPrecision
   157  	case ConvertToDecimal:
   158  		if c.cachedDecimalType == nil {
   159  			c.cachedDecimalType = createConvertedDecimalType(c.typeLength, c.typeScale, true)
   160  		}
   161  		return c.cachedDecimalType
   162  	case ConvertToFloat:
   163  		return types.Float32
   164  	case ConvertToDouble, ConvertToReal:
   165  		return types.Float64
   166  	case ConvertToJSON:
   167  		return types.JSON
   168  	case ConvertToSigned:
   169  		return types.Int64
   170  	case ConvertToTime:
   171  		return types.Time
   172  	case ConvertToUnsigned:
   173  		return types.Uint64
   174  	default:
   175  		return types.Null
   176  	}
   177  }
   178  
   179  // CollationCoercibility implements the interface sql.CollationCoercible.
   180  func (c *Convert) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   181  	switch c.castToType {
   182  	case ConvertToBinary:
   183  		return sql.Collation_binary, 2
   184  	case ConvertToChar, ConvertToNChar:
   185  		return ctx.GetCollation(), 2
   186  	case ConvertToDate:
   187  		return sql.Collation_binary, 5
   188  	case ConvertToDatetime:
   189  		return sql.Collation_binary, 5
   190  	case ConvertToDecimal:
   191  		return sql.Collation_binary, 5
   192  	case ConvertToDouble, ConvertToReal, ConvertToFloat:
   193  		return sql.Collation_binary, 5
   194  	case ConvertToJSON:
   195  		return ctx.GetCharacterSet().BinaryCollation(), 2
   196  	case ConvertToSigned:
   197  		return sql.Collation_binary, 5
   198  	case ConvertToTime:
   199  		return sql.Collation_binary, 5
   200  	case ConvertToUnsigned:
   201  		return sql.Collation_binary, 5
   202  	default:
   203  		return sql.Collation_binary, 7
   204  	}
   205  }
   206  
   207  // String implements the Stringer interface.
   208  func (c *Convert) String() string {
   209  	extraTypeInfo := ""
   210  	if c.typeLength > 0 {
   211  		if c.typeScale > 0 {
   212  			extraTypeInfo = fmt.Sprintf("(%d,%d)", c.typeLength, c.typeScale)
   213  		} else {
   214  			extraTypeInfo = fmt.Sprintf("(%d)", c.typeLength)
   215  		}
   216  	}
   217  	return fmt.Sprintf("convert(%v, %v%s)", c.Child, c.castToType, extraTypeInfo)
   218  }
   219  
   220  // DebugString implements the Expression interface.
   221  func (c *Convert) DebugString() string {
   222  	pr := sql.NewTreePrinter()
   223  	_ = pr.WriteNode("convert")
   224  	children := []string{
   225  		fmt.Sprintf("type: %v", c.castToType),
   226  	}
   227  
   228  	if c.typeLength > 0 {
   229  		children = append(children, fmt.Sprintf("typeLength: %v", c.typeLength))
   230  	}
   231  
   232  	if c.typeScale > 0 {
   233  		children = append(children, fmt.Sprintf("typeScale: %v", c.typeScale))
   234  	}
   235  
   236  	children = append(children, fmt.Sprintf(sql.DebugString(c.Child)))
   237  
   238  	_ = pr.WriteChildren(children...)
   239  	return pr.String()
   240  }
   241  
   242  // WithChildren implements the Expression interface.
   243  func (c *Convert) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   244  	if len(children) != 1 {
   245  		return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1)
   246  	}
   247  	return NewConvertWithLengthAndScale(children[0], c.castToType, c.typeLength, c.typeScale), nil
   248  }
   249  
   250  // Eval implements the Expression interface.
   251  func (c *Convert) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   252  	val, err := c.Child.Eval(ctx, row)
   253  	if err != nil {
   254  		return nil, err
   255  	}
   256  
   257  	if val == nil {
   258  		return nil, nil
   259  	}
   260  
   261  	// Should always return nil, and a warning instead
   262  	casted, err := convertValue(val, c.castToType, c.Child.Type(), c.typeLength, c.typeScale)
   263  	if err != nil {
   264  		if c.castToType == ConvertToJSON {
   265  			return nil, ErrConvertExpression.Wrap(err, c.String(), c.castToType)
   266  		}
   267  		ctx.Warn(1292, "Incorrect %s value: %v", c.castToType, val)
   268  		return nil, nil
   269  	}
   270  
   271  	return casted, nil
   272  }
   273  
   274  // convertValue only returns an error if converting to JSON, Date, and Datetime;
   275  // the zero value is returned for float types. Nil is returned in all other cases.
   276  // If |typeLength| and |typeScale| are 0, they are ignored, otherwise they are used as constraints on the
   277  // converted type where applicable (e.g. Char conversion supports only |typeLength|, Decimal conversion supports
   278  // |typeLength| and |typeScale|).
   279  func convertValue(val interface{}, castTo string, originType sql.Type, typeLength, typeScale int) (interface{}, error) {
   280  	switch strings.ToLower(castTo) {
   281  	case ConvertToBinary:
   282  		b, _, err := types.LongBlob.Convert(val)
   283  		if err != nil {
   284  			return nil, nil
   285  		}
   286  		if types.IsTextOnly(originType) {
   287  			// For string types we need to re-encode the string as we want the binary representation of the character set
   288  			encoder := originType.(sql.StringType).Collation().CharacterSet().Encoder()
   289  			encodedBytes, ok := encoder.Encode(b.([]byte))
   290  			if !ok {
   291  				return nil, fmt.Errorf("unable to re-encode string to convert to binary")
   292  			}
   293  			b = encodedBytes
   294  		}
   295  		return truncateConvertedValue(b, typeLength)
   296  	case ConvertToChar, ConvertToNChar:
   297  		s, _, err := types.LongText.Convert(val)
   298  		if err != nil {
   299  			return nil, nil
   300  		}
   301  		return truncateConvertedValue(s, typeLength)
   302  	case ConvertToDate:
   303  		_, isTime := val.(time.Time)
   304  		_, isString := val.(string)
   305  		_, isBinary := val.([]byte)
   306  		if !(isTime || isString || isBinary) {
   307  			return nil, nil
   308  		}
   309  		d, _, err := types.Date.Convert(val)
   310  		if err != nil {
   311  			return nil, err
   312  		}
   313  		return d, nil
   314  	case ConvertToDatetime:
   315  		_, isTime := val.(time.Time)
   316  		_, isString := val.(string)
   317  		_, isBinary := val.([]byte)
   318  		if !(isTime || isString || isBinary) {
   319  			return nil, nil
   320  		}
   321  		d, _, err := types.DatetimeMaxPrecision.Convert(val)
   322  		if err != nil {
   323  			return nil, err
   324  		}
   325  		return d, nil
   326  	case ConvertToDecimal:
   327  		value, err := convertHexBlobToDecimalForNumericContext(val, originType)
   328  		if err != nil {
   329  			return nil, err
   330  		}
   331  		dt := createConvertedDecimalType(typeLength, typeScale, false)
   332  		d, _, err := dt.Convert(value)
   333  		if err != nil {
   334  			return "0", nil
   335  		}
   336  		return d, nil
   337  	case ConvertToFloat:
   338  		value, err := convertHexBlobToDecimalForNumericContext(val, originType)
   339  		if err != nil {
   340  			return nil, err
   341  		}
   342  		d, _, err := types.Float32.Convert(value)
   343  		if err != nil {
   344  			return types.Float32.Zero(), nil
   345  		}
   346  		return d, nil
   347  	case ConvertToDouble, ConvertToReal:
   348  		value, err := convertHexBlobToDecimalForNumericContext(val, originType)
   349  		if err != nil {
   350  			return nil, err
   351  		}
   352  		d, _, err := types.Float64.Convert(value)
   353  		if err != nil {
   354  			return types.Float64.Zero(), nil
   355  		}
   356  		return d, nil
   357  	case ConvertToJSON:
   358  		js, _, err := types.JSON.Convert(val)
   359  		if err != nil {
   360  			return nil, err
   361  		}
   362  		return js, nil
   363  	case ConvertToSigned:
   364  		value, err := convertHexBlobToDecimalForNumericContext(val, originType)
   365  		if err != nil {
   366  			return nil, err
   367  		}
   368  		num, _, err := types.Int64.Convert(value)
   369  		if err != nil {
   370  			return types.Int64.Zero(), nil
   371  		}
   372  
   373  		return num, nil
   374  	case ConvertToTime:
   375  		t, _, err := types.Time.Convert(val)
   376  		if err != nil {
   377  			return nil, nil
   378  		}
   379  		return t, nil
   380  	case ConvertToUnsigned:
   381  		value, err := convertHexBlobToDecimalForNumericContext(val, originType)
   382  		if err != nil {
   383  			return nil, err
   384  		}
   385  		num, _, err := types.Uint64.Convert(value)
   386  		if err != nil {
   387  			num, _, err = types.Int64.Convert(value)
   388  			if err != nil {
   389  				return types.Uint64.Zero(), nil
   390  			}
   391  			return uint64(num.(int64)), nil
   392  		}
   393  		return num, nil
   394  	default:
   395  		return nil, nil
   396  	}
   397  }
   398  
   399  // truncateConvertedValue truncates |val| to the specified |typeLength| if |val|
   400  // is a string or byte slice. If the typeLength is 0, or if it is greater than
   401  // the length of |val|, then |val| is simply returned as is. If |val| is not a
   402  // string or []byte, then an error is returned.
   403  func truncateConvertedValue(val interface{}, typeLength int) (interface{}, error) {
   404  	if typeLength <= 0 {
   405  		return val, nil
   406  	}
   407  
   408  	switch v := val.(type) {
   409  	case []byte:
   410  		if len(v) <= typeLength {
   411  			typeLength = len(v)
   412  		}
   413  		return v[:typeLength], nil
   414  	case string:
   415  		if len(v) <= typeLength {
   416  			typeLength = len(v)
   417  		}
   418  		return v[:typeLength], nil
   419  	default:
   420  		return nil, fmt.Errorf("unsupported type for truncation: %T", val)
   421  	}
   422  }
   423  
   424  // createConvertedDecimalType creates a new Decimal type with the specified |precision| and |scale|. If a Decimal
   425  // type cannot be created from the values specified, the internal Decimal type is returned. If |logErrors| is true,
   426  // an error will also logged to the standard logger. (Setting |logErrors| to false, allows the caller to prevent
   427  // spurious error message from being logged multiple times for the same error.) This function is intended to be
   428  // used in places where an error cannot be returned (e.g. Node.Type() implementations), hence why it logs an error
   429  // instead of returning one.
   430  func createConvertedDecimalType(length, scale int, logErrors bool) sql.DecimalType {
   431  	if length > 0 && scale > 0 {
   432  		dt, err := types.CreateColumnDecimalType(uint8(length), uint8(scale))
   433  		if err != nil {
   434  			if logErrors {
   435  				logrus.StandardLogger().Errorf("unable to create decimal type with length %d and scale %d: %v", length, scale, err)
   436  			}
   437  			return types.InternalDecimalType
   438  		}
   439  		return dt
   440  	}
   441  	return types.InternalDecimalType
   442  }
   443  
   444  // convertHexBlobToDecimalForNumericContext converts byte array value to unsigned int value if originType is BLOB type.
   445  // This function is called when convertTo type is number type only. The hex literal values are parsed into blobs as
   446  // binary string as default, but for numeric context, the value should be a number.
   447  // Byte arrays of other SQL types are not handled here.
   448  func convertHexBlobToDecimalForNumericContext(val interface{}, originType sql.Type) (interface{}, error) {
   449  	if bin, isBinary := val.([]byte); isBinary && types.IsBlobType(originType) {
   450  		stringVal := hex.EncodeToString(bin)
   451  		decimalNum, err := strconv.ParseUint(stringVal, 16, 64)
   452  		if err != nil {
   453  			return nil, errors.NewKind("failed to convert hex blob value to unsigned int").New()
   454  		}
   455  		val = decimalNum
   456  	}
   457  	return val, nil
   458  }