vitess.io/vitess@v0.16.2/go/vt/vtgate/evalengine/arithmetic.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package evalengine
    18  
    19  import (
    20  	"bytes"
    21  	"strconv"
    22  	"strings"
    23  
    24  	"vitess.io/vitess/go/hack"
    25  	"vitess.io/vitess/go/sqltypes"
    26  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    27  	"vitess.io/vitess/go/vt/vterrors"
    28  	"vitess.io/vitess/go/vt/vtgate/evalengine/internal/decimal"
    29  )
    30  
    31  // evalengine represents a numeric value extracted from
    32  // a Value, used for arithmetic operations.
    33  var zeroBytes = []byte("0")
    34  
    35  func dataOutOfRangeError(v1, v2 any, typ, sign string) error {
    36  	return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "%s value is out of range in '(%v %s %v)'", typ, v1, sign, v2)
    37  }
    38  
    39  // FormatFloat formats a float64 as a byte string in a similar way to what MySQL does
    40  func FormatFloat(typ sqltypes.Type, f float64) []byte {
    41  	return AppendFloat(nil, typ, f)
    42  }
    43  
    44  func AppendFloat(buf []byte, typ sqltypes.Type, f float64) []byte {
    45  	format := byte('g')
    46  	if typ == sqltypes.Decimal {
    47  		format = 'f'
    48  	}
    49  
    50  	// the float printer in MySQL does not add a positive sign before
    51  	// the exponent for positive exponents, but the Golang printer does
    52  	// do that, and there's no way to customize it, so we must strip the
    53  	// redundant positive sign manually
    54  	// e.g. 1.234E+56789 -> 1.234E56789
    55  	fstr := strconv.AppendFloat(buf, f, format, -1, 64)
    56  	if idx := bytes.IndexByte(fstr, 'e'); idx >= 0 {
    57  		if fstr[idx+1] == '+' {
    58  			fstr = append(fstr[:idx+1], fstr[idx+2:]...)
    59  		}
    60  	}
    61  
    62  	return fstr
    63  }
    64  
    65  // Add adds two values together
    66  // if v1 or v2 is null, then it returns null
    67  func Add(v1, v2 sqltypes.Value) (sqltypes.Value, error) {
    68  	if v1.IsNull() || v2.IsNull() {
    69  		return sqltypes.NULL, nil
    70  	}
    71  
    72  	var lv1, lv2, out EvalResult
    73  	if err := lv1.setValue(v1, collationNumeric); err != nil {
    74  		return sqltypes.NULL, err
    75  	}
    76  	if err := lv2.setValue(v2, collationNumeric); err != nil {
    77  		return sqltypes.NULL, err
    78  	}
    79  
    80  	err := addNumericWithError(&lv1, &lv2, &out)
    81  	if err != nil {
    82  		return sqltypes.NULL, err
    83  	}
    84  	return out.Value(), nil
    85  }
    86  
    87  // Subtract takes two values and subtracts them
    88  func Subtract(v1, v2 sqltypes.Value) (sqltypes.Value, error) {
    89  	if v1.IsNull() || v2.IsNull() {
    90  		return sqltypes.NULL, nil
    91  	}
    92  
    93  	var lv1, lv2, out EvalResult
    94  	if err := lv1.setValue(v1, collationNumeric); err != nil {
    95  		return sqltypes.NULL, err
    96  	}
    97  	if err := lv2.setValue(v2, collationNumeric); err != nil {
    98  		return sqltypes.NULL, err
    99  	}
   100  
   101  	err := subtractNumericWithError(&lv1, &lv2, &out)
   102  	if err != nil {
   103  		return sqltypes.NULL, err
   104  	}
   105  
   106  	return out.Value(), nil
   107  }
   108  
   109  // Multiply takes two values and multiplies it together
   110  func Multiply(v1, v2 sqltypes.Value) (sqltypes.Value, error) {
   111  	if v1.IsNull() || v2.IsNull() {
   112  		return sqltypes.NULL, nil
   113  	}
   114  
   115  	var lv1, lv2, out EvalResult
   116  	if err := lv1.setValue(v1, collationNumeric); err != nil {
   117  		return sqltypes.NULL, err
   118  	}
   119  	if err := lv2.setValue(v2, collationNumeric); err != nil {
   120  		return sqltypes.NULL, err
   121  	}
   122  
   123  	err := multiplyNumericWithError(&lv1, &lv2, &out)
   124  	if err != nil {
   125  		return sqltypes.NULL, err
   126  	}
   127  
   128  	return out.Value(), nil
   129  }
   130  
   131  // Divide (Float) for MySQL. Replicates behavior of "/" operator
   132  func Divide(v1, v2 sqltypes.Value) (sqltypes.Value, error) {
   133  	if v1.IsNull() || v2.IsNull() {
   134  		return sqltypes.NULL, nil
   135  	}
   136  
   137  	var lv1, lv2, out EvalResult
   138  	if err := lv1.setValue(v1, collationNumeric); err != nil {
   139  		return sqltypes.NULL, err
   140  	}
   141  	if err := lv2.setValue(v2, collationNumeric); err != nil {
   142  		return sqltypes.NULL, err
   143  	}
   144  
   145  	err := divideNumericWithError(&lv1, &lv2, true, &out)
   146  	if err != nil {
   147  		return sqltypes.NULL, err
   148  	}
   149  
   150  	return out.Value(), nil
   151  }
   152  
   153  // NullSafeAdd adds two Values in a null-safe manner. A null value
   154  // is treated as 0. If both values are null, then a null is returned.
   155  // If both values are not null, a numeric value is built
   156  // from each input: Signed->int64, Unsigned->uint64, Float->float64.
   157  // Otherwise the 'best type fit' is chosen for the number: int64 or float64.
   158  // OpAddition is performed by upgrading types as needed, or in case
   159  // of overflow: int64->uint64, int64->float64, uint64->float64.
   160  // Unsigned ints can only be added to positive ints. After the
   161  // addition, if one of the input types was Decimal, then
   162  // a Decimal is built. Otherwise, the final type of the
   163  // result is preserved.
   164  func NullSafeAdd(v1, v2 sqltypes.Value, resultType sqltypes.Type) (sqltypes.Value, error) {
   165  	if v1.IsNull() {
   166  		v1 = sqltypes.MakeTrusted(resultType, zeroBytes)
   167  	}
   168  	if v2.IsNull() {
   169  		v2 = sqltypes.MakeTrusted(resultType, zeroBytes)
   170  	}
   171  
   172  	var lv1, lv2, out EvalResult
   173  	if err := lv1.setValue(v1, collationNumeric); err != nil {
   174  		return sqltypes.NULL, err
   175  	}
   176  	if err := lv2.setValue(v2, collationNumeric); err != nil {
   177  		return sqltypes.NULL, err
   178  	}
   179  
   180  	err := addNumericWithError(&lv1, &lv2, &out)
   181  	if err != nil {
   182  		return sqltypes.NULL, err
   183  	}
   184  	return out.toSQLValue(resultType), nil
   185  }
   186  
   187  func addNumericWithError(v1, v2, out *EvalResult) error {
   188  	v1, v2 = makeNumericAndPrioritize(v1, v2)
   189  	switch v1.typeof() {
   190  	case sqltypes.Int64:
   191  		return intPlusIntWithError(v1.uint64(), v2.uint64(), out)
   192  	case sqltypes.Uint64:
   193  		switch v2.typeof() {
   194  		case sqltypes.Int64:
   195  			return uintPlusIntWithError(v1.uint64(), v2.uint64(), out)
   196  		case sqltypes.Uint64:
   197  			return uintPlusUintWithError(v1.uint64(), v2.uint64(), out)
   198  		}
   199  	case sqltypes.Decimal:
   200  		decimalPlusAny(v1.decimal(), v1.length_, v2, out)
   201  		return nil
   202  	case sqltypes.Float64:
   203  		return floatPlusAny(v1.float64(), v2, out)
   204  	}
   205  	return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.value().String(), v2.value().String())
   206  }
   207  
   208  func subtractNumericWithError(v1, v2, out *EvalResult) error {
   209  	v1.makeNumeric()
   210  	v2.makeNumeric()
   211  	switch v1.typeof() {
   212  	case sqltypes.Int64:
   213  		switch v2.typeof() {
   214  		case sqltypes.Int64:
   215  			return intMinusIntWithError(v1.uint64(), v2.uint64(), out)
   216  		case sqltypes.Uint64:
   217  			return intMinusUintWithError(v1.uint64(), v2.uint64(), out)
   218  		case sqltypes.Float64:
   219  			return anyMinusFloat(v1, v2.float64(), out)
   220  		case sqltypes.Decimal:
   221  			anyMinusDecimal(v1, v2.decimal(), v2.length_, out)
   222  			return nil
   223  		}
   224  	case sqltypes.Uint64:
   225  		switch v2.typeof() {
   226  		case sqltypes.Int64:
   227  			return uintMinusIntWithError(v1.uint64(), v2.uint64(), out)
   228  		case sqltypes.Uint64:
   229  			return uintMinusUintWithError(v1.uint64(), v2.uint64(), out)
   230  		case sqltypes.Float64:
   231  			return anyMinusFloat(v1, v2.float64(), out)
   232  		case sqltypes.Decimal:
   233  			anyMinusDecimal(v1, v2.decimal(), v2.length_, out)
   234  			return nil
   235  		}
   236  	case sqltypes.Float64:
   237  		return floatMinusAny(v1.float64(), v2, out)
   238  	case sqltypes.Decimal:
   239  		switch v2.typeof() {
   240  		case sqltypes.Float64:
   241  			return anyMinusFloat(v1, v2.float64(), out)
   242  		default:
   243  			decimalMinusAny(v1.decimal(), v1.length_, v2, out)
   244  			return nil
   245  		}
   246  	}
   247  	return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.value().String(), v2.value().String())
   248  }
   249  
   250  func multiplyNumericWithError(v1, v2, out *EvalResult) error {
   251  	v1, v2 = makeNumericAndPrioritize(v1, v2)
   252  	switch v1.typeof() {
   253  	case sqltypes.Int64:
   254  		return intTimesIntWithError(v1.uint64(), v2.uint64(), out)
   255  	case sqltypes.Uint64:
   256  		switch v2.typeof() {
   257  		case sqltypes.Int64:
   258  			return uintTimesIntWithError(v1.uint64(), v2.uint64(), out)
   259  		case sqltypes.Uint64:
   260  			return uintTimesUintWithError(v1.uint64(), v2.uint64(), out)
   261  		}
   262  	case sqltypes.Float64:
   263  		return floatTimesAny(v1.float64(), v2, out)
   264  	case sqltypes.Decimal:
   265  		decimalTimesAny(v1.decimal(), v1.length_, v2, out)
   266  		return nil
   267  	}
   268  	return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.value().String(), v2.value().String())
   269  
   270  }
   271  
   272  func divideNumericWithError(v1, v2 *EvalResult, precise bool, out *EvalResult) error {
   273  	v1.makeNumeric()
   274  	v2.makeNumeric()
   275  	if !precise && v1.typeof() != sqltypes.Decimal && v2.typeof() != sqltypes.Decimal {
   276  		switch v1.typeof() {
   277  		case sqltypes.Int64:
   278  			return floatDivideAnyWithError(float64(v1.int64()), v2, out)
   279  
   280  		case sqltypes.Uint64:
   281  			return floatDivideAnyWithError(float64(v1.uint64()), v2, out)
   282  
   283  		case sqltypes.Float64:
   284  			return floatDivideAnyWithError(v1.float64(), v2, out)
   285  		}
   286  	}
   287  	switch {
   288  	case v1.typeof() == sqltypes.Float64:
   289  		return floatDivideAnyWithError(v1.float64(), v2, out)
   290  	case v2.typeof() == sqltypes.Float64:
   291  		v1f, err := v1.coerceToFloat()
   292  		if err != nil {
   293  			return err
   294  		}
   295  		return floatDivideAnyWithError(v1f, v2, out)
   296  	default:
   297  		decimalDivide(v1, v2, divPrecisionIncrement, out)
   298  		return nil
   299  	}
   300  }
   301  
   302  // makeNumericAndPrioritize reorders the input parameters
   303  // to be Float64, Decimal, Uint64, Int64.
   304  func makeNumericAndPrioritize(i1, i2 *EvalResult) (*EvalResult, *EvalResult) {
   305  	i1.makeNumeric()
   306  	i2.makeNumeric()
   307  	switch i1.typeof() {
   308  	case sqltypes.Int64:
   309  		if i2.typeof() == sqltypes.Uint64 || i2.typeof() == sqltypes.Float64 || i2.typeof() == sqltypes.Decimal {
   310  			return i2, i1
   311  		}
   312  	case sqltypes.Uint64:
   313  		if i2.typeof() == sqltypes.Float64 || i2.typeof() == sqltypes.Decimal {
   314  			return i2, i1
   315  		}
   316  	case sqltypes.Decimal:
   317  		if i2.typeof() == sqltypes.Float64 {
   318  			return i2, i1
   319  		}
   320  	}
   321  	return i1, i2
   322  }
   323  
   324  func intPlusIntWithError(v1u, v2u uint64, out *EvalResult) error {
   325  	v1, v2 := int64(v1u), int64(v2u)
   326  	result := v1 + v2
   327  	if (result > v1) != (v2 > 0) {
   328  		return dataOutOfRangeError(v1, v2, "BIGINT", "+")
   329  	}
   330  	out.setInt64(result)
   331  	return nil
   332  }
   333  
   334  func intMinusIntWithError(v1u, v2u uint64, out *EvalResult) error {
   335  	v1, v2 := int64(v1u), int64(v2u)
   336  	result := v1 - v2
   337  
   338  	if (result < v1) != (v2 > 0) {
   339  		return dataOutOfRangeError(v1, v2, "BIGINT", "-")
   340  	}
   341  	out.setInt64(result)
   342  	return nil
   343  }
   344  
   345  func intTimesIntWithError(v1u, v2u uint64, out *EvalResult) error {
   346  	v1, v2 := int64(v1u), int64(v2u)
   347  	result := v1 * v2
   348  	if v1 != 0 && result/v1 != v2 {
   349  		return dataOutOfRangeError(v1, v2, "BIGINT", "*")
   350  	}
   351  	out.setInt64(result)
   352  	return nil
   353  
   354  }
   355  
   356  func intMinusUintWithError(v1u uint64, v2 uint64, out *EvalResult) error {
   357  	v1 := int64(v1u)
   358  	if v1 < 0 || v1 < int64(v2) {
   359  		return dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-")
   360  	}
   361  	return uintMinusUintWithError(v1u, v2, out)
   362  }
   363  
   364  func uintPlusIntWithError(v1 uint64, v2u uint64, out *EvalResult) error {
   365  	v2 := int64(v2u)
   366  	result := v1 + uint64(v2)
   367  	if v2 < 0 && v1 < uint64(-v2) || v2 > 0 && (result < v1 || result < uint64(v2)) {
   368  		return dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+")
   369  	}
   370  	// convert to int -> uint is because for numeric operators (such as + or -)
   371  	// where one of the operands is an unsigned integer, the result is unsigned by default.
   372  	out.setUint64(result)
   373  	return nil
   374  }
   375  
   376  func uintMinusIntWithError(v1 uint64, v2u uint64, out *EvalResult) error {
   377  	v2 := int64(v2u)
   378  	if int64(v1) < v2 && v2 > 0 {
   379  		return dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-")
   380  	}
   381  	// uint - (- int) = uint + int
   382  	if v2 < 0 {
   383  		return uintPlusIntWithError(v1, uint64(-v2), out)
   384  	}
   385  	return uintMinusUintWithError(v1, uint64(v2), out)
   386  }
   387  
   388  func uintTimesIntWithError(v1 uint64, v2u uint64, out *EvalResult) error {
   389  	v2 := int64(v2u)
   390  	if v1 == 0 || v2 == 0 {
   391  		out.setUint64(0)
   392  		return nil
   393  	}
   394  	if v2 < 0 || int64(v1) < 0 {
   395  		return dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "*")
   396  	}
   397  	return uintTimesUintWithError(v1, uint64(v2), out)
   398  }
   399  
   400  func uintPlusUintWithError(v1, v2 uint64, out *EvalResult) error {
   401  	result := v1 + v2
   402  	if result < v1 || result < v2 {
   403  		return dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+")
   404  	}
   405  	out.setUint64(result)
   406  	return nil
   407  }
   408  
   409  func uintMinusUintWithError(v1, v2 uint64, out *EvalResult) error {
   410  	result := v1 - v2
   411  	if v2 > v1 {
   412  		return dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-")
   413  	}
   414  	out.setUint64(result)
   415  	return nil
   416  }
   417  
   418  func uintTimesUintWithError(v1, v2 uint64, out *EvalResult) error {
   419  	if v1 == 0 || v2 == 0 {
   420  		out.setUint64(0)
   421  		return nil
   422  	}
   423  	result := v1 * v2
   424  	if result < v2 || result < v1 {
   425  		return dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "*")
   426  	}
   427  	out.setUint64(result)
   428  	return nil
   429  }
   430  
   431  func floatPlusAny(v1 float64, v2 *EvalResult, out *EvalResult) error {
   432  	v2f, err := v2.coerceToFloat()
   433  	if err != nil {
   434  		return err
   435  	}
   436  	add := v1 + v2f
   437  	out.setFloat(add)
   438  	return nil
   439  }
   440  
   441  func floatMinusAny(v1 float64, v2 *EvalResult, out *EvalResult) error {
   442  	v2f, err := v2.coerceToFloat()
   443  	if err != nil {
   444  		return err
   445  	}
   446  	out.setFloat(v1 - v2f)
   447  	return nil
   448  }
   449  
   450  func floatTimesAny(v1 float64, v2 *EvalResult, out *EvalResult) error {
   451  	v2f, err := v2.coerceToFloat()
   452  	if err != nil {
   453  		return err
   454  	}
   455  	out.setFloat(v1 * v2f)
   456  	return nil
   457  }
   458  
   459  func maxprec(a, b int32) int32 {
   460  	if a > b {
   461  		return a
   462  	}
   463  	return b
   464  }
   465  
   466  func decimalPlusAny(v1 decimal.Decimal, f1 int32, v2 *EvalResult, out *EvalResult) {
   467  	v2d := v2.coerceToDecimal()
   468  	out.setDecimal(v1.Add(v2d), maxprec(f1, v2.length_))
   469  }
   470  
   471  func decimalMinusAny(v1 decimal.Decimal, f1 int32, v2 *EvalResult, out *EvalResult) {
   472  	v2d := v2.coerceToDecimal()
   473  	out.setDecimal(v1.Sub(v2d), maxprec(f1, v2.length_))
   474  }
   475  
   476  func anyMinusDecimal(v1 *EvalResult, v2 decimal.Decimal, f2 int32, out *EvalResult) {
   477  	v1d := v1.coerceToDecimal()
   478  	out.setDecimal(v1d.Sub(v2), maxprec(v1.length_, f2))
   479  }
   480  
   481  func decimalTimesAny(v1 decimal.Decimal, f1 int32, v2 *EvalResult, out *EvalResult) {
   482  	v2d := v2.coerceToDecimal()
   483  	out.setDecimal(v1.Mul(v2d), maxprec(f1, v2.length_))
   484  }
   485  
   486  const divPrecisionIncrement = 4
   487  
   488  func decimalDivide(v1, v2 *EvalResult, incrPrecision int32, out *EvalResult) {
   489  	v1d := v1.coerceToDecimal()
   490  	v2d := v2.coerceToDecimal()
   491  	if v2d.IsZero() {
   492  		out.setNull()
   493  		return
   494  	}
   495  	out.setDecimal(v1d.Div(v2d, incrPrecision), v1.length_+incrPrecision)
   496  }
   497  
   498  func floatDivideAnyWithError(v1 float64, v2 *EvalResult, out *EvalResult) error {
   499  	v2f, err := v2.coerceToFloat()
   500  	if err != nil {
   501  		return err
   502  	}
   503  	if v2f == 0.0 {
   504  		out.setNull()
   505  		return nil
   506  	}
   507  
   508  	result := v1 / v2f
   509  	divisorLessThanOne := v2f < 1
   510  	resultMismatch := v2f*result != v1
   511  
   512  	if divisorLessThanOne && resultMismatch {
   513  		return dataOutOfRangeError(v1, v2f, "BIGINT", "/")
   514  	}
   515  
   516  	out.setFloat(result)
   517  	return nil
   518  }
   519  
   520  func anyMinusFloat(v1 *EvalResult, v2 float64, out *EvalResult) error {
   521  	v1f, err := v1.coerceToFloat()
   522  	if err != nil {
   523  		return err
   524  	}
   525  	out.setFloat(v1f - v2)
   526  	return nil
   527  }
   528  
   529  func parseStringToFloat(str string) float64 {
   530  	str = strings.TrimSpace(str)
   531  
   532  	// We only care to parse as many of the initial float characters of the
   533  	// string as possible. This functionality is implemented in the `strconv` package
   534  	// of the standard library, but not exposed, so we hook into it.
   535  	val, _, err := hack.ParseFloatPrefix(str, 64)
   536  	if err != nil {
   537  		return 0.0
   538  	}
   539  	return val
   540  }