vitess.io/vitess@v0.16.2/go/vt/vtgate/evalengine/evalengine.go (about)

     1  /*
     2  Copyright 2020 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package evalengine
    18  
    19  import (
    20  	"math"
    21  	"time"
    22  
    23  	"vitess.io/vitess/go/mysql/collations"
    24  	"vitess.io/vitess/go/sqltypes"
    25  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    26  	"vitess.io/vitess/go/vt/sqlparser"
    27  	"vitess.io/vitess/go/vt/vterrors"
    28  	"vitess.io/vitess/go/vt/vtgate/evalengine/internal/decimal"
    29  )
    30  
    31  // Cast converts a Value to the target type.
    32  func Cast(v sqltypes.Value, typ sqltypes.Type) (sqltypes.Value, error) {
    33  	if v.Type() == typ || v.IsNull() {
    34  		return v, nil
    35  	}
    36  	vBytes, err := v.ToBytes()
    37  	if err != nil {
    38  		return v, err
    39  	}
    40  	if sqltypes.IsSigned(typ) && v.IsSigned() {
    41  		return sqltypes.MakeTrusted(typ, vBytes), nil
    42  	}
    43  	if sqltypes.IsUnsigned(typ) && v.IsUnsigned() {
    44  		return sqltypes.MakeTrusted(typ, vBytes), nil
    45  	}
    46  	if (sqltypes.IsFloat(typ) || typ == sqltypes.Decimal) && (v.IsIntegral() || v.IsFloat() || v.Type() == sqltypes.Decimal) {
    47  		return sqltypes.MakeTrusted(typ, vBytes), nil
    48  	}
    49  	if sqltypes.IsQuoted(typ) && (v.IsIntegral() || v.IsFloat() || v.Type() == sqltypes.Decimal || v.IsQuoted()) {
    50  		return sqltypes.MakeTrusted(typ, vBytes), nil
    51  	}
    52  
    53  	// Explicitly disallow Expression.
    54  	if v.Type() == sqltypes.Expression {
    55  		return sqltypes.NULL, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v cannot be cast to %v", v, typ)
    56  	}
    57  
    58  	// If the above fast-paths were not possible,
    59  	// go through full validation.
    60  	return sqltypes.NewValue(typ, vBytes)
    61  }
    62  
    63  // ToUint64 converts Value to uint64.
    64  func ToUint64(v sqltypes.Value) (uint64, error) {
    65  	var num EvalResult
    66  	if err := num.setValueIntegralNumeric(v); err != nil {
    67  		return 0, err
    68  	}
    69  	switch num.typeof() {
    70  	case sqltypes.Int64:
    71  		if num.uint64() > math.MaxInt64 {
    72  			return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "negative number cannot be converted to unsigned: %d", num.int64())
    73  		}
    74  		return num.uint64(), nil
    75  	case sqltypes.Uint64:
    76  		return num.uint64(), nil
    77  	}
    78  	panic("unreachable")
    79  }
    80  
    81  // ToInt64 converts Value to int64.
    82  func ToInt64(v sqltypes.Value) (int64, error) {
    83  	var num EvalResult
    84  	if err := num.setValueIntegralNumeric(v); err != nil {
    85  		return 0, err
    86  	}
    87  	switch num.typeof() {
    88  	case sqltypes.Int64:
    89  		return num.int64(), nil
    90  	case sqltypes.Uint64:
    91  		ival := num.int64()
    92  		if ival < 0 {
    93  			return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsigned number overflows int64 value: %d", num.uint64())
    94  		}
    95  		return ival, nil
    96  	}
    97  	panic("unreachable")
    98  }
    99  
   100  // ToFloat64 converts Value to float64.
   101  func ToFloat64(v sqltypes.Value) (float64, error) {
   102  	var num EvalResult
   103  	if err := num.setValue(v, collationNumeric); err != nil {
   104  		return 0, err
   105  	}
   106  	num.makeFloat()
   107  	return num.float64(), nil
   108  }
   109  
   110  func LiteralToValue(literal *sqlparser.Literal) (sqltypes.Value, error) {
   111  	lit, err := translateLiteral(literal, nil)
   112  	if err != nil {
   113  		return sqltypes.Value{}, err
   114  	}
   115  	return lit.Val.Value(), nil
   116  }
   117  
   118  // ToNative converts Value to a native go type.
   119  // Decimal is returned as []byte.
   120  func ToNative(v sqltypes.Value) (any, error) {
   121  	var out any
   122  	var err error
   123  	switch {
   124  	case v.Type() == sqltypes.Null:
   125  		// no-op
   126  	case v.IsSigned():
   127  		return ToInt64(v)
   128  	case v.IsUnsigned():
   129  		return ToUint64(v)
   130  	case v.IsFloat():
   131  		return ToFloat64(v)
   132  	case v.IsQuoted() || v.Type() == sqltypes.Bit || v.Type() == sqltypes.Decimal:
   133  		out, err = v.ToBytes()
   134  	case v.Type() == sqltypes.Expression:
   135  		err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v cannot be converted to a go type", v)
   136  	}
   137  	return out, err
   138  }
   139  
// compareNumeric compares two numeric EvalResults. It returns -1, 0 or 1 when
// v1 is respectively less than, equal to, or greater than v2.
//
// Both arguments are modified in place. They are first upcast to their 64-bit
// variants. Then one of them may be converted (setUint64/setFloat/setDecimal)
// so that both share a common type before the final comparison.
//
// An error is returned only when a DECIMAL operand must be converted to
// float64 and the decimal's Float64 conversion reports failure.
//
// NOTE(review): if v1 is not Int64/Uint64/Float64/Decimal after equalization,
// control falls through to the final `return 1`. Presumably callers only pass
// numeric values; confirm. Likewise, a float comparison involving NaN fails
// both `==` and `<`, and therefore also yields 1.
func compareNumeric(v1, v2 *EvalResult) (int, error) {
	// upcast all <64 bit numeric types to 64 bit, e.g. int8 -> int64, uint8 -> uint64, float32 -> float64
	// so we don't have to consider integer types which aren't 64 bit
	v1.upcastNumeric()
	v2.upcastNumeric()

	// Equalize the types the same way MySQL does
	// https://dev.mysql.com/doc/refman/8.0/en/type-conversion.html
	switch v1.typeof() {
	case sqltypes.Int64:
		switch v2.typeof() {
		case sqltypes.Uint64:
			// An Int64 whose uint64 view exceeds MaxInt64 is treated as negative,
			// and a negative value is below every unsigned value.
			if v1.uint64() > math.MaxInt64 {
				return -1, nil
			}
			// Non-negative: safe to compare both sides as uint64.
			v1.setUint64(v1.uint64())
		case sqltypes.Float64:
			v1.setFloat(float64(v1.int64()))
		case sqltypes.Decimal:
			v1.setDecimal(decimal.NewFromInt(v1.int64()), 0)
		}
	case sqltypes.Uint64:
		switch v2.typeof() {
		case sqltypes.Int64:
			// Mirror of the case above: a negative v2 is below any unsigned v1.
			if v2.uint64() > math.MaxInt64 {
				return 1, nil
			}
			v2.setUint64(v2.uint64())
		case sqltypes.Float64:
			v1.setFloat(float64(v1.uint64()))
		case sqltypes.Decimal:
			v1.setDecimal(decimal.NewFromUint(v1.uint64()), 0)
		}
	case sqltypes.Float64:
		switch v2.typeof() {
		case sqltypes.Int64:
			v2.setFloat(float64(v2.int64()))
		case sqltypes.Uint64:
			// Shortcut: a negative float is below every unsigned value.
			if v1.float64() < 0 {
				return -1, nil
			}
			v2.setFloat(float64(v2.uint64()))
		case sqltypes.Decimal:
			// FLOAT vs DECIMAL is compared as float64.
			f, ok := v2.decimal().Float64()
			if !ok {
				return 0, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "DECIMAL value is out of range")
			}
			v2.setFloat(f)
		}
	case sqltypes.Decimal:
		switch v2.typeof() {
		case sqltypes.Int64:
			v2.setDecimal(decimal.NewFromInt(v2.int64()), 0)
		case sqltypes.Uint64:
			v2.setDecimal(decimal.NewFromUint(v2.uint64()), 0)
		case sqltypes.Float64:
			// DECIMAL vs FLOAT is compared as float64, like the FLOAT vs DECIMAL case.
			f, ok := v1.decimal().Float64()
			if !ok {
				return 0, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "DECIMAL value is out of range")
			}
			v1.setFloat(f)
		}
	}

	// Both values are of the same type.
	switch v1.typeof() {
	case sqltypes.Int64:
		v1v, v2v := v1.int64(), v2.int64()
		switch {
		case v1v == v2v:
			return 0, nil
		case v1v < v2v:
			return -1, nil
		}
	case sqltypes.Uint64:
		switch {
		case v1.uint64() == v2.uint64():
			return 0, nil
		case v1.uint64() < v2.uint64():
			return -1, nil
		}
	case sqltypes.Float64:
		v1v, v2v := v1.float64(), v2.float64()
		switch {
		case v1v == v2v:
			return 0, nil
		case v1v < v2v:
			return -1, nil
		}
	case sqltypes.Decimal:
		return v1.decimal().Cmp(v2.decimal()), nil
	}
	// Reached when v1 > v2, and in the fall-through cases noted in the doc comment.
	return 1, nil
}
   234  
   235  func parseDate(expr *EvalResult) (t time.Time, err error) {
   236  	switch expr.typeof() {
   237  	case sqltypes.Date:
   238  		t, err = sqlparser.ParseDate(expr.string())
   239  	case sqltypes.Timestamp, sqltypes.Datetime:
   240  		t, err = sqlparser.ParseDateTime(expr.string())
   241  	case sqltypes.Time:
   242  		t, err = sqlparser.ParseTime(expr.string())
   243  	}
   244  	return
   245  }
   246  
   247  // matchExprWithAnyDateFormat formats the given expr (usually a string) to a date using the first format
   248  // that does not return an error.
   249  func matchExprWithAnyDateFormat(expr *EvalResult) (t time.Time, err error) {
   250  	t, err = sqlparser.ParseDate(expr.string())
   251  	if err == nil {
   252  		return
   253  	}
   254  	t, err = sqlparser.ParseDateTime(expr.string())
   255  	if err == nil {
   256  		return
   257  	}
   258  	t, err = sqlparser.ParseTime(expr.string())
   259  	return
   260  }
   261  
   262  // Date comparison based on:
   263  //   - https://dev.mysql.com/doc/refman/8.0/en/type-conversion.html
   264  //   - https://dev.mysql.com/doc/refman/8.0/en/date-and-time-type-conversion.html
   265  func compareDates(l, r *EvalResult) (int, error) {
   266  	lTime, err := parseDate(l)
   267  	if err != nil {
   268  		return 0, err
   269  	}
   270  	rTime, err := parseDate(r)
   271  	if err != nil {
   272  		return 0, err
   273  	}
   274  
   275  	return compareGoTimes(lTime, rTime)
   276  }
   277  
   278  func compareDateAndString(l, r *EvalResult) (int, error) {
   279  	var lTime, rTime time.Time
   280  	var err error
   281  	switch {
   282  	case sqltypes.IsDate(l.typeof()):
   283  		lTime, err = parseDate(l)
   284  		if err != nil {
   285  			return 0, err
   286  		}
   287  		rTime, err = matchExprWithAnyDateFormat(r)
   288  		if err != nil {
   289  			return 0, err
   290  		}
   291  	case l.isTextual():
   292  		rTime, err = parseDate(r)
   293  		if err != nil {
   294  			return 0, err
   295  		}
   296  		lTime, err = matchExprWithAnyDateFormat(l)
   297  		if err != nil {
   298  			return 0, err
   299  		}
   300  	}
   301  	return compareGoTimes(lTime, rTime)
   302  }
   303  
   304  func compareGoTimes(lTime, rTime time.Time) (int, error) {
   305  	if lTime.Before(rTime) {
   306  		return -1, nil
   307  	}
   308  	if lTime.After(rTime) {
   309  		return 1, nil
   310  	}
   311  	return 0, nil
   312  }
   313  
   314  // More on string collations coercibility on MySQL documentation:
   315  //   - https://dev.mysql.com/doc/refman/8.0/en/charset-collation-coercibility.html
   316  func compareStrings(l, r *EvalResult) int {
   317  	coll, err := mergeCollations(l, r)
   318  	if err != nil {
   319  		throwEvalError(err)
   320  	}
   321  	collation := collations.Local().LookupByID(coll)
   322  	if collation == nil {
   323  		panic("unknown collation after coercion")
   324  	}
   325  	return collation.Collate(l.bytes(), r.bytes(), false)
   326  }