vitess.io/vitess@v0.16.2/go/vt/vtgate/evalengine/arithmetic.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package evalengine 18 19 import ( 20 "bytes" 21 "strconv" 22 "strings" 23 24 "vitess.io/vitess/go/hack" 25 "vitess.io/vitess/go/sqltypes" 26 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 27 "vitess.io/vitess/go/vt/vterrors" 28 "vitess.io/vitess/go/vt/vtgate/evalengine/internal/decimal" 29 ) 30 31 // evalengine represents a numeric value extracted from 32 // a Value, used for arithmetic operations. 33 var zeroBytes = []byte("0") 34 35 func dataOutOfRangeError(v1, v2 any, typ, sign string) error { 36 return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "%s value is out of range in '(%v %s %v)'", typ, v1, sign, v2) 37 } 38 39 // FormatFloat formats a float64 as a byte string in a similar way to what MySQL does 40 func FormatFloat(typ sqltypes.Type, f float64) []byte { 41 return AppendFloat(nil, typ, f) 42 } 43 44 func AppendFloat(buf []byte, typ sqltypes.Type, f float64) []byte { 45 format := byte('g') 46 if typ == sqltypes.Decimal { 47 format = 'f' 48 } 49 50 // the float printer in MySQL does not add a positive sign before 51 // the exponent for positive exponents, but the Golang printer does 52 // do that, and there's no way to customize it, so we must strip the 53 // redundant positive sign manually 54 // e.g. 1.234E+56789 -> 1.234E56789 55 fstr := strconv.AppendFloat(buf, f, format, -1, 64) 56 if idx := bytes.IndexByte(fstr, 'e'); idx >= 0 { 57 if fstr[idx+1] == '+' { 58 fstr = append(fstr[:idx+1], fstr[idx+2:]...) 59 } 60 } 61 62 return fstr 63 } 64 65 // Add adds two values together 66 // if v1 or v2 is null, then it returns null 67 func Add(v1, v2 sqltypes.Value) (sqltypes.Value, error) { 68 if v1.IsNull() || v2.IsNull() { 69 return sqltypes.NULL, nil 70 } 71 72 var lv1, lv2, out EvalResult 73 if err := lv1.setValue(v1, collationNumeric); err != nil { 74 return sqltypes.NULL, err 75 } 76 if err := lv2.setValue(v2, collationNumeric); err != nil { 77 return sqltypes.NULL, err 78 } 79 80 err := addNumericWithError(&lv1, &lv2, &out) 81 if err != nil { 82 return sqltypes.NULL, err 83 } 84 return out.Value(), nil 85 } 86 87 // Subtract takes two values and subtracts them 88 func Subtract(v1, v2 sqltypes.Value) (sqltypes.Value, error) { 89 if v1.IsNull() || v2.IsNull() { 90 return sqltypes.NULL, nil 91 } 92 93 var lv1, lv2, out EvalResult 94 if err := lv1.setValue(v1, collationNumeric); err != nil { 95 return sqltypes.NULL, err 96 } 97 if err := lv2.setValue(v2, collationNumeric); err != nil { 98 return sqltypes.NULL, err 99 } 100 101 err := subtractNumericWithError(&lv1, &lv2, &out) 102 if err != nil { 103 return sqltypes.NULL, err 104 } 105 106 return out.Value(), nil 107 } 108 109 // Multiply takes two values and multiplies it together 110 func Multiply(v1, v2 sqltypes.Value) (sqltypes.Value, error) { 111 if v1.IsNull() || v2.IsNull() { 112 return sqltypes.NULL, nil 113 } 114 115 var lv1, lv2, out EvalResult 116 if err := lv1.setValue(v1, collationNumeric); err != nil { 117 return sqltypes.NULL, err 118 } 119 if err := lv2.setValue(v2, collationNumeric); err != nil { 120 return sqltypes.NULL, err 121 } 122 123 err := multiplyNumericWithError(&lv1, &lv2, &out) 124 if err != nil { 125 return sqltypes.NULL, err 126 } 127 128 return out.Value(), nil 129 } 130 131 // Divide (Float) for MySQL. Replicates behavior of "/" operator 132 func Divide(v1, v2 sqltypes.Value) (sqltypes.Value, error) { 133 if v1.IsNull() || v2.IsNull() { 134 return sqltypes.NULL, nil 135 } 136 137 var lv1, lv2, out EvalResult 138 if err := lv1.setValue(v1, collationNumeric); err != nil { 139 return sqltypes.NULL, err 140 } 141 if err := lv2.setValue(v2, collationNumeric); err != nil { 142 return sqltypes.NULL, err 143 } 144 145 err := divideNumericWithError(&lv1, &lv2, true, &out) 146 if err != nil { 147 return sqltypes.NULL, err 148 } 149 150 return out.Value(), nil 151 } 152 153 // NullSafeAdd adds two Values in a null-safe manner. A null value 154 // is treated as 0. If both values are null, then a null is returned. 155 // If both values are not null, a numeric value is built 156 // from each input: Signed->int64, Unsigned->uint64, Float->float64. 157 // Otherwise the 'best type fit' is chosen for the number: int64 or float64. 158 // OpAddition is performed by upgrading types as needed, or in case 159 // of overflow: int64->uint64, int64->float64, uint64->float64. 160 // Unsigned ints can only be added to positive ints. After the 161 // addition, if one of the input types was Decimal, then 162 // a Decimal is built. Otherwise, the final type of the 163 // result is preserved. 164 func NullSafeAdd(v1, v2 sqltypes.Value, resultType sqltypes.Type) (sqltypes.Value, error) { 165 if v1.IsNull() { 166 v1 = sqltypes.MakeTrusted(resultType, zeroBytes) 167 } 168 if v2.IsNull() { 169 v2 = sqltypes.MakeTrusted(resultType, zeroBytes) 170 } 171 172 var lv1, lv2, out EvalResult 173 if err := lv1.setValue(v1, collationNumeric); err != nil { 174 return sqltypes.NULL, err 175 } 176 if err := lv2.setValue(v2, collationNumeric); err != nil { 177 return sqltypes.NULL, err 178 } 179 180 err := addNumericWithError(&lv1, &lv2, &out) 181 if err != nil { 182 return sqltypes.NULL, err 183 } 184 return out.toSQLValue(resultType), nil 185 } 186 187 func addNumericWithError(v1, v2, out *EvalResult) error { 188 v1, v2 = makeNumericAndPrioritize(v1, v2) 189 switch v1.typeof() { 190 case sqltypes.Int64: 191 return intPlusIntWithError(v1.uint64(), v2.uint64(), out) 192 case sqltypes.Uint64: 193 switch v2.typeof() { 194 case sqltypes.Int64: 195 return uintPlusIntWithError(v1.uint64(), v2.uint64(), out) 196 case sqltypes.Uint64: 197 return uintPlusUintWithError(v1.uint64(), v2.uint64(), out) 198 } 199 case sqltypes.Decimal: 200 decimalPlusAny(v1.decimal(), v1.length_, v2, out) 201 return nil 202 case sqltypes.Float64: 203 return floatPlusAny(v1.float64(), v2, out) 204 } 205 return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.value().String(), v2.value().String()) 206 } 207 208 func subtractNumericWithError(v1, v2, out *EvalResult) error { 209 v1.makeNumeric() 210 v2.makeNumeric() 211 switch v1.typeof() { 212 case sqltypes.Int64: 213 switch v2.typeof() { 214 case sqltypes.Int64: 215 return intMinusIntWithError(v1.uint64(), v2.uint64(), out) 216 case sqltypes.Uint64: 217 return intMinusUintWithError(v1.uint64(), v2.uint64(), out) 218 case sqltypes.Float64: 219 return anyMinusFloat(v1, v2.float64(), out) 220 case sqltypes.Decimal: 221 anyMinusDecimal(v1, v2.decimal(), v2.length_, out) 222 return nil 223 } 224 case sqltypes.Uint64: 225 switch v2.typeof() { 226 case sqltypes.Int64: 227 return uintMinusIntWithError(v1.uint64(), v2.uint64(), out) 228 case sqltypes.Uint64: 229 return uintMinusUintWithError(v1.uint64(), v2.uint64(), out) 230 case sqltypes.Float64: 231 return anyMinusFloat(v1, v2.float64(), out) 232 case sqltypes.Decimal: 233 anyMinusDecimal(v1, v2.decimal(), v2.length_, out) 234 return nil 235 } 236 case sqltypes.Float64: 237 return floatMinusAny(v1.float64(), v2, out) 238 case sqltypes.Decimal: 239 switch v2.typeof() { 240 case sqltypes.Float64: 241 return anyMinusFloat(v1, v2.float64(), out) 242 default: 243 decimalMinusAny(v1.decimal(), v1.length_, v2, out) 244 return nil 245 } 246 } 247 return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.value().String(), v2.value().String()) 248 } 249 250 func multiplyNumericWithError(v1, v2, out *EvalResult) error { 251 v1, v2 = makeNumericAndPrioritize(v1, v2) 252 switch v1.typeof() { 253 case sqltypes.Int64: 254 return intTimesIntWithError(v1.uint64(), v2.uint64(), out) 255 case sqltypes.Uint64: 256 switch v2.typeof() { 257 case sqltypes.Int64: 258 return uintTimesIntWithError(v1.uint64(), v2.uint64(), out) 259 case sqltypes.Uint64: 260 return uintTimesUintWithError(v1.uint64(), v2.uint64(), out) 261 } 262 case sqltypes.Float64: 263 return floatTimesAny(v1.float64(), v2, out) 264 case sqltypes.Decimal: 265 decimalTimesAny(v1.decimal(), v1.length_, v2, out) 266 return nil 267 } 268 return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.value().String(), v2.value().String()) 269 270 } 271 272 func divideNumericWithError(v1, v2 *EvalResult, precise bool, out *EvalResult) error { 273 v1.makeNumeric() 274 v2.makeNumeric() 275 if !precise && v1.typeof() != sqltypes.Decimal && v2.typeof() != sqltypes.Decimal { 276 switch v1.typeof() { 277 case sqltypes.Int64: 278 return floatDivideAnyWithError(float64(v1.int64()), v2, out) 279 280 case sqltypes.Uint64: 281 return floatDivideAnyWithError(float64(v1.uint64()), v2, out) 282 283 case sqltypes.Float64: 284 return floatDivideAnyWithError(v1.float64(), v2, out) 285 } 286 } 287 switch { 288 case v1.typeof() == sqltypes.Float64: 289 return floatDivideAnyWithError(v1.float64(), v2, out) 290 case v2.typeof() == sqltypes.Float64: 291 v1f, err := v1.coerceToFloat() 292 if err != nil { 293 return err 294 } 295 return floatDivideAnyWithError(v1f, v2, out) 296 default: 297 decimalDivide(v1, v2, divPrecisionIncrement, out) 298 return nil 299 } 300 } 301 302 // makeNumericAndPrioritize reorders the input parameters 303 // to be Float64, Decimal, Uint64, Int64. 304 func makeNumericAndPrioritize(i1, i2 *EvalResult) (*EvalResult, *EvalResult) { 305 i1.makeNumeric() 306 i2.makeNumeric() 307 switch i1.typeof() { 308 case sqltypes.Int64: 309 if i2.typeof() == sqltypes.Uint64 || i2.typeof() == sqltypes.Float64 || i2.typeof() == sqltypes.Decimal { 310 return i2, i1 311 } 312 case sqltypes.Uint64: 313 if i2.typeof() == sqltypes.Float64 || i2.typeof() == sqltypes.Decimal { 314 return i2, i1 315 } 316 case sqltypes.Decimal: 317 if i2.typeof() == sqltypes.Float64 { 318 return i2, i1 319 } 320 } 321 return i1, i2 322 } 323 324 func intPlusIntWithError(v1u, v2u uint64, out *EvalResult) error { 325 v1, v2 := int64(v1u), int64(v2u) 326 result := v1 + v2 327 if (result > v1) != (v2 > 0) { 328 return dataOutOfRangeError(v1, v2, "BIGINT", "+") 329 } 330 out.setInt64(result) 331 return nil 332 } 333 334 func intMinusIntWithError(v1u, v2u uint64, out *EvalResult) error { 335 v1, v2 := int64(v1u), int64(v2u) 336 result := v1 - v2 337 338 if (result < v1) != (v2 > 0) { 339 return dataOutOfRangeError(v1, v2, "BIGINT", "-") 340 } 341 out.setInt64(result) 342 return nil 343 } 344 345 func intTimesIntWithError(v1u, v2u uint64, out *EvalResult) error { 346 v1, v2 := int64(v1u), int64(v2u) 347 result := v1 * v2 348 if v1 != 0 && result/v1 != v2 { 349 return dataOutOfRangeError(v1, v2, "BIGINT", "*") 350 } 351 out.setInt64(result) 352 return nil 353 354 } 355 356 func intMinusUintWithError(v1u uint64, v2 uint64, out *EvalResult) error { 357 v1 := int64(v1u) 358 if v1 < 0 || v1 < int64(v2) { 359 return dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-") 360 } 361 return uintMinusUintWithError(v1u, v2, out) 362 } 363 364 func uintPlusIntWithError(v1 uint64, v2u uint64, out *EvalResult) error { 365 v2 := int64(v2u) 366 result := v1 + uint64(v2) 367 if v2 < 0 && v1 < uint64(-v2) || v2 > 0 && (result < v1 || result < uint64(v2)) { 368 return dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") 369 } 370 // convert to int -> uint is because for numeric operators (such as + or -) 371 // where one of the operands is an unsigned integer, the result is unsigned by default. 372 out.setUint64(result) 373 return nil 374 } 375 376 func uintMinusIntWithError(v1 uint64, v2u uint64, out *EvalResult) error { 377 v2 := int64(v2u) 378 if int64(v1) < v2 && v2 > 0 { 379 return dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-") 380 } 381 // uint - (- int) = uint + int 382 if v2 < 0 { 383 return uintPlusIntWithError(v1, uint64(-v2), out) 384 } 385 return uintMinusUintWithError(v1, uint64(v2), out) 386 } 387 388 func uintTimesIntWithError(v1 uint64, v2u uint64, out *EvalResult) error { 389 v2 := int64(v2u) 390 if v1 == 0 || v2 == 0 { 391 out.setUint64(0) 392 return nil 393 } 394 if v2 < 0 || int64(v1) < 0 { 395 return dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "*") 396 } 397 return uintTimesUintWithError(v1, uint64(v2), out) 398 } 399 400 func uintPlusUintWithError(v1, v2 uint64, out *EvalResult) error { 401 result := v1 + v2 402 if result < v1 || result < v2 { 403 return dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") 404 } 405 out.setUint64(result) 406 return nil 407 } 408 409 func uintMinusUintWithError(v1, v2 uint64, out *EvalResult) error { 410 result := v1 - v2 411 if v2 > v1 { 412 return dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-") 413 } 414 out.setUint64(result) 415 return nil 416 } 417 418 func uintTimesUintWithError(v1, v2 uint64, out *EvalResult) error { 419 if v1 == 0 || v2 == 0 { 420 out.setUint64(0) 421 return nil 422 } 423 result := v1 * v2 424 if result < v2 || result < v1 { 425 return dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "*") 426 } 427 out.setUint64(result) 428 return nil 429 } 430 431 func floatPlusAny(v1 float64, v2 *EvalResult, out *EvalResult) error { 432 v2f, err := v2.coerceToFloat() 433 if err != nil { 434 return err 435 } 436 add := v1 + v2f 437 out.setFloat(add) 438 return nil 439 } 440 441 func floatMinusAny(v1 float64, v2 *EvalResult, out *EvalResult) error { 442 v2f, err := v2.coerceToFloat() 443 if err != nil { 444 return err 445 } 446 out.setFloat(v1 - v2f) 447 return nil 448 } 449 450 func floatTimesAny(v1 float64, v2 *EvalResult, out *EvalResult) error { 451 v2f, err := v2.coerceToFloat() 452 if err != nil { 453 return err 454 } 455 out.setFloat(v1 * v2f) 456 return nil 457 } 458 459 func maxprec(a, b int32) int32 { 460 if a > b { 461 return a 462 } 463 return b 464 } 465 466 func decimalPlusAny(v1 decimal.Decimal, f1 int32, v2 *EvalResult, out *EvalResult) { 467 v2d := v2.coerceToDecimal() 468 out.setDecimal(v1.Add(v2d), maxprec(f1, v2.length_)) 469 } 470 471 func decimalMinusAny(v1 decimal.Decimal, f1 int32, v2 *EvalResult, out *EvalResult) { 472 v2d := v2.coerceToDecimal() 473 out.setDecimal(v1.Sub(v2d), maxprec(f1, v2.length_)) 474 } 475 476 func anyMinusDecimal(v1 *EvalResult, v2 decimal.Decimal, f2 int32, out *EvalResult) { 477 v1d := v1.coerceToDecimal() 478 out.setDecimal(v1d.Sub(v2), maxprec(v1.length_, f2)) 479 } 480 481 func decimalTimesAny(v1 decimal.Decimal, f1 int32, v2 *EvalResult, out *EvalResult) { 482 v2d := v2.coerceToDecimal() 483 out.setDecimal(v1.Mul(v2d), maxprec(f1, v2.length_)) 484 } 485 486 const divPrecisionIncrement = 4 487 488 func decimalDivide(v1, v2 *EvalResult, incrPrecision int32, out *EvalResult) { 489 v1d := v1.coerceToDecimal() 490 v2d := v2.coerceToDecimal() 491 if v2d.IsZero() { 492 out.setNull() 493 return 494 } 495 out.setDecimal(v1d.Div(v2d, incrPrecision), v1.length_+incrPrecision) 496 } 497 498 func floatDivideAnyWithError(v1 float64, v2 *EvalResult, out *EvalResult) error { 499 v2f, err := v2.coerceToFloat() 500 if err != nil { 501 return err 502 } 503 if v2f == 0.0 { 504 out.setNull() 505 return nil 506 } 507 508 result := v1 / v2f 509 divisorLessThanOne := v2f < 1 510 resultMismatch := v2f*result != v1 511 512 if divisorLessThanOne && resultMismatch { 513 return dataOutOfRangeError(v1, v2f, "BIGINT", "/") 514 } 515 516 out.setFloat(result) 517 return nil 518 } 519 520 func anyMinusFloat(v1 *EvalResult, v2 float64, out *EvalResult) error { 521 v1f, err := v1.coerceToFloat() 522 if err != nil { 523 return err 524 } 525 out.setFloat(v1f - v2) 526 return nil 527 } 528 529 func parseStringToFloat(str string) float64 { 530 str = strings.TrimSpace(str) 531 532 // We only care to parse as many of the initial float characters of the 533 // string as possible. This functionality is implemented in the `strconv` package 534 // of the standard library, but not exposed, so we hook into it. 535 val, _, err := hack.ParseFloatPrefix(str, 64) 536 if err != nil { 537 return 0.0 538 } 539 return val 540 }