github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/arithmetic.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expression 16 17 import ( 18 "fmt" 19 "reflect" 20 "regexp" 21 "strconv" 22 "strings" 23 "time" 24 25 "github.com/dolthub/vitess/go/mysql" 26 "github.com/dolthub/vitess/go/vt/sqlparser" 27 "github.com/shopspring/decimal" 28 errors "gopkg.in/src-d/go-errors.v1" 29 30 "github.com/dolthub/go-mysql-server/sql" 31 "github.com/dolthub/go-mysql-server/sql/types" 32 ) 33 34 var ( 35 // errUnableToCast means that we could not find common type for two arithemtic objects 36 errUnableToCast = errors.NewKind("Unable to cast between types: %T, %T") 37 38 // errUnableToEval means that we could not evaluate an expression 39 errUnableToEval = errors.NewKind("Unable to evaluate an expression: %v %s %v") 40 41 timeTypeRegex = regexp.MustCompile("[0-9]+") 42 ) 43 44 func arithmeticWarning(ctx *sql.Context, errCode int, errMsg string) { 45 if ctx != nil && ctx.Session != nil { 46 ctx.Session.Warn(&sql.Warning{ 47 Level: "Warning", 48 Code: errCode, 49 Message: errMsg, 50 }) 51 } 52 } 53 54 // ArithmeticOp implements an arithmetic expression. Since we had separate expressions 55 // for division and mod operation, we need to group all arithmetic together. Use this 56 // expression to define any arithmetic operation that is separately implemented from 57 // Arithmetic expression in the future. 58 type ArithmeticOp interface { 59 sql.Expression 60 BinaryExpression 61 SetOpCount(int32) 62 Operator() string 63 } 64 65 var _ ArithmeticOp = (*Arithmetic)(nil) 66 var _ sql.CollationCoercible = (*Arithmetic)(nil) 67 68 // Arithmetic expressions include plus, minus and multiplication (+, -, *) operations. 69 type Arithmetic struct { 70 BinaryExpressionStub 71 Op string 72 ops int32 73 } 74 75 // NewArithmetic creates a new Arithmetic sql.Expression. 76 func NewArithmetic(left, right sql.Expression, op string) *Arithmetic { 77 a := &Arithmetic{BinaryExpressionStub{LeftChild: left, RightChild: right}, op, 0} 78 ops := countArithmeticOps(a) 79 setArithmeticOps(a, ops) 80 return a 81 } 82 83 // NewPlus creates a new Arithmetic + sql.Expression. 84 func NewPlus(left, right sql.Expression) *Arithmetic { 85 return NewArithmetic(left, right, sqlparser.PlusStr) 86 } 87 88 // NewMinus creates a new Arithmetic - sql.Expression. 89 func NewMinus(left, right sql.Expression) *Arithmetic { 90 return NewArithmetic(left, right, sqlparser.MinusStr) 91 } 92 93 // NewMult creates a new Arithmetic * sql.Expression. 94 func NewMult(left, right sql.Expression) *Arithmetic { 95 return NewArithmetic(left, right, sqlparser.MultStr) 96 } 97 98 func (a *Arithmetic) Operator() string { 99 return a.Op 100 } 101 102 func (a *Arithmetic) SetOpCount(i int32) { 103 a.ops = i 104 } 105 106 func (a *Arithmetic) String() string { 107 return fmt.Sprintf("(%s %s %s)", a.LeftChild.String(), a.Op, a.RightChild.String()) 108 } 109 110 func (a *Arithmetic) DebugString() string { 111 return fmt.Sprintf("(%s %s %s)", sql.DebugString(a.LeftChild), a.Op, sql.DebugString(a.RightChild)) 112 } 113 114 // IsNullable implements the sql.Expression interface. 115 func (a *Arithmetic) IsNullable() bool { 116 if types.IsDatetimeType(a.Type()) || types.IsTimestampType(a.Type()) { 117 return true 118 } 119 120 return a.BinaryExpressionStub.IsNullable() 121 } 122 123 // Type returns the greatest type for given operation. 124 func (a *Arithmetic) Type() sql.Type { 125 //TODO: what if both BindVars? should be constant folded 126 rTyp := a.RightChild.Type() 127 if types.IsDeferredType(rTyp) { 128 return rTyp 129 } 130 lTyp := a.LeftChild.Type() 131 if types.IsDeferredType(lTyp) { 132 return lTyp 133 } 134 135 // applies for + and - ops 136 if isInterval(a.LeftChild) || isInterval(a.RightChild) { 137 // TODO: need to use the precision stored in datetimeType; something like 138 // return types.MustCreateDatetimeType(sqltypes.Datetime, 0) 139 return types.Datetime 140 } 141 142 if types.IsText(lTyp) || types.IsText(rTyp) { 143 return types.Float64 144 } 145 146 if types.IsJSON(lTyp) || types.IsJSON(rTyp) { 147 return types.Float64 148 } 149 150 if types.IsFloat(lTyp) || types.IsFloat(rTyp) { 151 return types.Float64 152 } 153 154 if types.IsYear(lTyp) && types.IsYear(rTyp) { 155 // MySQL just returns the largest int that fits 156 return types.Uint64 157 } 158 159 // Bit types are integers 160 if types.IsBit(lTyp) { 161 lTyp = types.Int64 162 } 163 if types.IsBit(rTyp) { 164 rTyp = types.Int64 165 } 166 167 // Dates are Integers 168 if types.IsDateType(lTyp) { 169 lTyp = types.Int64 170 } 171 if types.IsDateType(rTyp) { 172 rTyp = types.Int64 173 } 174 175 // Datetime(0) is treated as Int64, otherwise as Decimal 176 if types.IsDatetimeType(lTyp) { 177 if dtType, ok := lTyp.(sql.DatetimeType); ok { 178 scale := uint8(dtType.Precision()) 179 if scale == 0 { 180 lTyp = types.Int64 181 } else { 182 lTyp = types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, scale) 183 } 184 } 185 } 186 if types.IsDatetimeType(rTyp) { 187 if dtType, ok := rTyp.(sql.DatetimeType); ok { 188 scale := uint8(dtType.Precision()) 189 if scale == 0 { 190 rTyp = types.Int64 191 } else { 192 rTyp = types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, scale) 193 } 194 } 195 } 196 197 if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) { 198 return types.Uint64 199 } 200 201 if types.IsInteger(lTyp) && types.IsInteger(rTyp) { 202 return types.Int64 203 } 204 205 if types.IsDecimal(lTyp) && !types.IsDecimal(rTyp) { 206 return lTyp 207 } 208 209 if types.IsDecimal(rTyp) && !types.IsDecimal(lTyp) { 210 return rTyp 211 } 212 213 if types.IsDecimal(lTyp) && types.IsDecimal(rTyp) { 214 lPrec := lTyp.(sql.DecimalType).Precision() 215 lScale := lTyp.(sql.DecimalType).Scale() 216 rPrec := rTyp.(sql.DecimalType).Precision() 217 rScale := rTyp.(sql.DecimalType).Scale() 218 219 var prec, scale uint8 220 if lPrec > rPrec { 221 prec = lPrec 222 } else { 223 prec = rPrec 224 } 225 226 switch a.Op { 227 case sqlparser.PlusStr, sqlparser.MinusStr: 228 if lScale > rScale { 229 scale = lScale 230 } else { 231 scale = rScale 232 } 233 prec = prec + scale 234 case sqlparser.MultStr: 235 scale = lScale + rScale 236 prec = prec + scale 237 } 238 239 if prec > types.DecimalTypeMaxPrecision { 240 prec = types.DecimalTypeMaxPrecision 241 } 242 if scale > types.DecimalTypeMaxScale { 243 scale = types.DecimalTypeMaxScale 244 } 245 246 return types.MustCreateDecimalType(prec, scale) 247 } 248 249 // When in doubt return float64 250 return types.Float64 251 } 252 253 // CollationCoercibility implements the interface sql.CollationCoercible. 254 func (*Arithmetic) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 255 return sql.Collation_binary, 5 256 } 257 258 // WithChildren implements the Expression interface. 259 func (a *Arithmetic) WithChildren(children ...sql.Expression) (sql.Expression, error) { 260 if len(children) != 2 { 261 return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 2) 262 } 263 // sanity check 264 switch strings.ToLower(a.Op) { 265 case sqlparser.DivStr: 266 return NewDiv(children[0], children[1]), nil 267 case sqlparser.ModStr: 268 return NewMod(children[0], children[1]), nil 269 } 270 return NewArithmetic(children[0], children[1], a.Op), nil 271 } 272 273 // Eval implements the Expression interface. 274 func (a *Arithmetic) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 275 lval, rval, err := a.evalLeftRight(ctx, row) 276 if err != nil { 277 return nil, err 278 } 279 280 if lval == nil || rval == nil { 281 return nil, nil 282 } 283 284 lval, rval, err = a.convertLeftRight(ctx, lval, rval) 285 if err != nil { 286 return nil, err 287 } 288 289 var result interface{} 290 switch strings.ToLower(a.Op) { 291 case sqlparser.PlusStr: 292 result, err = plus(lval, rval) 293 case sqlparser.MinusStr: 294 result, err = minus(lval, rval) 295 case sqlparser.MultStr: 296 result, err = mult(lval, rval) 297 } 298 299 if err != nil { 300 return nil, err 301 } 302 303 // Decimals must be rounded 304 if res, ok := result.(decimal.Decimal); ok { 305 if isOutermostArithmeticOp(a, a.ops) { 306 finalScale, hasDiv := getFinalScale(ctx, row, a, 0) 307 if hasDiv { 308 // TODO: should always round regardless; we have bad Decimal defaults 309 return res.Round(finalScale), nil 310 } 311 } 312 } 313 314 return result, nil 315 } 316 317 func (a *Arithmetic) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) { 318 var lval, rval interface{} 319 var err error 320 321 if i, ok := a.LeftChild.(*Interval); ok { 322 lval, err = i.EvalDelta(ctx, row) 323 if err != nil { 324 return nil, nil, err 325 } 326 } else { 327 lval, err = a.LeftChild.Eval(ctx, row) 328 if err != nil { 329 return nil, nil, err 330 } 331 } 332 333 if i, ok := a.RightChild.(*Interval); ok { 334 rval, err = i.EvalDelta(ctx, row) 335 if err != nil { 336 return nil, nil, err 337 } 338 } else { 339 rval, err = a.RightChild.Eval(ctx, row) 340 if err != nil { 341 return nil, nil, err 342 } 343 } 344 345 return lval, rval, nil 346 } 347 348 func (a *Arithmetic) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}, error) { 349 typ := a.Type() 350 351 lIsTimeType := types.IsTime(a.LeftChild.Type()) 352 rIsTimeType := types.IsTime(a.RightChild.Type()) 353 354 if i, ok := left.(*TimeDelta); ok { 355 left = i 356 } else { 357 // these are the types we specifically want to capture from we get from Type() 358 if types.IsInteger(typ) || types.IsFloat(typ) || types.IsTime(typ) { 359 left = convertValueToType(ctx, typ, left, lIsTimeType) 360 } else { 361 left = convertToDecimalValue(left, lIsTimeType) 362 } 363 } 364 365 if i, ok := right.(*TimeDelta); ok { 366 right = i 367 } else { 368 // these are the types we specifically want to capture from we get from Type() 369 if types.IsInteger(typ) || types.IsFloat(typ) || types.IsTime(typ) { 370 right = convertValueToType(ctx, typ, right, rIsTimeType) 371 } else { 372 right = convertToDecimalValue(right, rIsTimeType) 373 } 374 } 375 376 return left, right, nil 377 } 378 379 func isInterval(expr sql.Expression) bool { 380 _, ok := expr.(*Interval) 381 return ok 382 } 383 384 // countArithmeticOps returns the number of arithmetic operators under the current node. 385 // This lets us count how many arithmetic operators used one after the other 386 func countArithmeticOps(e sql.Expression) int32 { 387 if e == nil { 388 return 0 389 } 390 391 if a, ok := e.(ArithmeticOp); ok { 392 return countArithmeticOps(a.Left()) + countArithmeticOps(a.Right()) + 1 393 } 394 395 return 0 396 } 397 398 // setArithmeticOps will set ops number with number counted by countArithmeticOps. This allows 399 // us to keep track of whether the expression is the last arithmetic operation. 400 func setArithmeticOps(e sql.Expression, ops int32) { 401 if e == nil { 402 return 403 } 404 405 if a, ok := e.(ArithmeticOp); ok { 406 a.SetOpCount(ops) 407 setArithmeticOps(a.Left(), ops) 408 setArithmeticOps(a.Right(), ops) 409 } 410 411 if tup, ok := e.(Tuple); ok { 412 for _, expr := range tup { 413 setArithmeticOps(expr, ops) 414 } 415 } 416 417 return 418 } 419 420 // isOutermostArithmeticOp return whether the expression we're currently on is 421 // the last arithmetic operation of all continuous arithmetic operations. 422 func isOutermostArithmeticOp(e sql.Expression, opScale int32) bool { 423 return opScale == countArithmeticOps(e) 424 } 425 426 // convertValueToType returns |val| converted into type |typ|. If the value is 427 // invalid and cannot be converted to the given type, it returns nil, and it should be 428 // interpreted as value of 0. For time types, all the numbers are parsed up to seconds only. 429 // E.g: `2022-11-10 12:14:36` is parsed into `20221110121436` and `2022-03-24` is parsed into `20220324`. 430 func convertValueToType(ctx *sql.Context, typ sql.Type, val interface{}, isTimeType bool) interface{} { 431 var cval interface{} 432 if isTimeType { 433 val = convertTimeTypeToString(val) 434 } 435 436 cval, _, err := typ.Convert(val) 437 if err != nil { 438 arithmeticWarning(ctx, mysql.ERTruncatedWrongValue, fmt.Sprintf("Truncated incorrect %s value: '%v'", typ.String(), val)) 439 // the value is interpreted as 0, but we need to match the type of the other valid value 440 // to avoid additional conversion, the nil value is handled in each operation 441 } 442 return cval 443 } 444 445 // convertTimeTypeToString returns string value parsed from either time.Time or string 446 // representation. all the numbers are parsed up to seconds only. The location can be 447 // different between two time.Time values, so we set it to default UTC location before 448 // parsing. E.g: 449 // `2022-11-10 12:14:36` is parsed into `20221110121436` 450 // `2022-03-24` is parsed into `20220324`. 451 func convertTimeTypeToString(val interface{}) interface{} { 452 if t, ok := val.(time.Time); ok { 453 val = t.In(time.UTC).Format("2006-01-02 15:04:05") 454 } 455 if t, ok := val.(string); ok { 456 nums := timeTypeRegex.FindAllString(t, -1) 457 val = strings.Join(nums, "") 458 } 459 460 return val 461 } 462 463 func plus(lval, rval interface{}) (interface{}, error) { 464 switch l := lval.(type) { 465 case uint8: 466 switch r := rval.(type) { 467 case uint8: 468 return l + r, nil 469 } 470 case int8: 471 switch r := rval.(type) { 472 case int8: 473 return l + r, nil 474 } 475 case uint16: 476 switch r := rval.(type) { 477 case uint16: 478 return l + r, nil 479 } 480 case int16: 481 switch r := rval.(type) { 482 case int16: 483 return l + r, nil 484 } 485 case uint32: 486 switch r := rval.(type) { 487 case uint32: 488 return l + r, nil 489 } 490 case int32: 491 switch r := rval.(type) { 492 case int32: 493 return l + r, nil 494 } 495 case uint64: 496 switch r := rval.(type) { 497 case uint64: 498 return l + r, nil 499 } 500 case int64: 501 switch r := rval.(type) { 502 case int64: 503 return l + r, nil 504 } 505 case float32: 506 switch r := rval.(type) { 507 case float32: 508 return l + r, nil 509 } 510 case float64: 511 switch r := rval.(type) { 512 case float64: 513 return l + r, nil 514 } 515 case decimal.Decimal: 516 switch r := rval.(type) { 517 case decimal.Decimal: 518 return l.Add(r), nil 519 } 520 case time.Time: 521 switch r := rval.(type) { 522 case *TimeDelta: 523 return types.ValidateTime(r.Add(l)), nil 524 case time.Time: 525 return l.Unix() + r.Unix(), nil 526 } 527 case *TimeDelta: 528 switch r := rval.(type) { 529 case time.Time: 530 return types.ValidateTime(l.Add(r)), nil 531 } 532 } 533 534 return nil, errUnableToCast.New(lval, rval) 535 } 536 537 func minus(lval, rval interface{}) (interface{}, error) { 538 switch l := lval.(type) { 539 case uint8: 540 switch r := rval.(type) { 541 case uint8: 542 return l - r, nil 543 } 544 case int8: 545 switch r := rval.(type) { 546 case int8: 547 return l - r, nil 548 } 549 case uint16: 550 switch r := rval.(type) { 551 case uint16: 552 return l - r, nil 553 } 554 case int16: 555 switch r := rval.(type) { 556 case int16: 557 return l - r, nil 558 } 559 case uint32: 560 switch r := rval.(type) { 561 case uint32: 562 return l - r, nil 563 } 564 case int32: 565 switch r := rval.(type) { 566 case int32: 567 return l - r, nil 568 } 569 case uint64: 570 switch r := rval.(type) { 571 case uint64: 572 return l - r, nil 573 } 574 case int64: 575 switch r := rval.(type) { 576 case int64: 577 return l - r, nil 578 } 579 case float32: 580 switch r := rval.(type) { 581 case float32: 582 return l - r, nil 583 } 584 case float64: 585 switch r := rval.(type) { 586 case float64: 587 return l - r, nil 588 } 589 case decimal.Decimal: 590 switch r := rval.(type) { 591 case decimal.Decimal: 592 return l.Sub(r), nil 593 } 594 case time.Time: 595 switch r := rval.(type) { 596 case *TimeDelta: 597 return types.ValidateTime(r.Sub(l)), nil 598 case time.Time: 599 return l.Unix() - r.Unix(), nil 600 } 601 } 602 603 return nil, errUnableToCast.New(lval, rval) 604 } 605 606 func mult(lval, rval interface{}) (interface{}, error) { 607 switch l := lval.(type) { 608 case uint8: 609 switch r := rval.(type) { 610 case uint8: 611 return l * r, nil 612 } 613 case int8: 614 switch r := rval.(type) { 615 case int8: 616 return l * r, nil 617 } 618 case uint16: 619 switch r := rval.(type) { 620 case uint16: 621 return l * r, nil 622 } 623 case int16: 624 switch r := rval.(type) { 625 case int16: 626 return l * r, nil 627 } 628 case uint32: 629 switch r := rval.(type) { 630 case uint32: 631 return l * r, nil 632 } 633 case int32: 634 switch r := rval.(type) { 635 case int32: 636 return l * r, nil 637 } 638 case uint64: 639 switch r := rval.(type) { 640 case uint64: 641 return l * r, nil 642 } 643 case int64: 644 switch r := rval.(type) { 645 case int64: 646 return l * r, nil 647 } 648 case float32: 649 switch r := rval.(type) { 650 case float32: 651 return l * r, nil 652 } 653 case float64: 654 switch r := rval.(type) { 655 case float64: 656 return l * r, nil 657 } 658 case decimal.Decimal: 659 switch r := rval.(type) { 660 case decimal.Decimal: 661 return l.Mul(r), nil 662 } 663 } 664 665 return nil, errUnableToCast.New(lval, rval) 666 } 667 668 // UnaryMinus is an unary minus operator. 669 type UnaryMinus struct { 670 UnaryExpression 671 } 672 673 var _ sql.Expression = (*UnaryMinus)(nil) 674 var _ sql.CollationCoercible = (*UnaryMinus)(nil) 675 676 // NewUnaryMinus creates a new UnaryMinus expression node. 677 func NewUnaryMinus(child sql.Expression) *UnaryMinus { 678 return &UnaryMinus{UnaryExpression{Child: child}} 679 } 680 681 // Eval implements the sql.Expression interface. 682 func (e *UnaryMinus) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 683 child, err := e.Child.Eval(ctx, row) 684 if err != nil { 685 return nil, err 686 } 687 688 if child == nil { 689 return nil, nil 690 } 691 692 if !types.IsNumber(e.Child.Type()) { 693 child, err = decimal.NewFromString(fmt.Sprintf("%v", child)) 694 if err != nil { 695 child = 0.0 696 } 697 } 698 699 switch n := child.(type) { 700 case float64: 701 return -n, nil 702 case float32: 703 return -n, nil 704 case int: 705 return -n, nil 706 case int8: 707 return -n, nil 708 case int16: 709 return -n, nil 710 case int32: 711 return -n, nil 712 case int64: 713 return -n, nil 714 case uint: 715 return -int(n), nil 716 case uint8: 717 return -int8(n), nil 718 case uint16: 719 return -int16(n), nil 720 case uint32: 721 return -int32(n), nil 722 case uint64: 723 return -int64(n), nil 724 case decimal.Decimal: 725 return n.Neg(), err 726 case string: 727 // try getting int out of string value 728 i, iErr := strconv.ParseInt(n, 10, 64) 729 if iErr != nil { 730 return nil, sql.ErrInvalidType.New(reflect.TypeOf(n)) 731 } 732 return -i, nil 733 default: 734 return nil, sql.ErrInvalidType.New(reflect.TypeOf(n)) 735 } 736 } 737 738 // Type implements the sql.Expression interface. 739 func (e *UnaryMinus) Type() sql.Type { 740 typ := e.Child.Type() 741 if !types.IsNumber(typ) { 742 return types.Float64 743 } 744 745 if typ == types.Uint32 { 746 return types.Int32 747 } 748 749 if typ == types.Uint64 { 750 return types.Int64 751 } 752 753 return e.Child.Type() 754 } 755 756 // CollationCoercibility implements the interface sql.CollationCoercible. 757 func (*UnaryMinus) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 758 return sql.Collation_binary, 5 759 } 760 761 func (e *UnaryMinus) String() string { 762 return fmt.Sprintf("-%s", e.Child) 763 } 764 765 // WithChildren implements the Expression interface. 766 func (e *UnaryMinus) WithChildren(children ...sql.Expression) (sql.Expression, error) { 767 if len(children) != 1 { 768 return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) 769 } 770 return NewUnaryMinus(children[0]), nil 771 }