github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/div.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expression 16 17 import ( 18 "fmt" 19 "math" 20 "strconv" 21 "strings" 22 "time" 23 24 "github.com/dolthub/vitess/go/vt/sqlparser" 25 "github.com/shopspring/decimal" 26 "gopkg.in/src-d/go-errors.v1" 27 28 "github.com/dolthub/go-mysql-server/sql" 29 "github.com/dolthub/go-mysql-server/sql/types" 30 ) 31 32 var ErrIntDivDataOutOfRange = errors.NewKind("BIGINT value is out of range (%s DIV %s)") 33 34 // '4 scales' are added to scale of the number on the left side of division operator at every division operation. 35 // The default value is 4, and it can be set using sysvar https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html#sysvar_div_precision_increment 36 const divPrecInc = 4 37 38 // '9 scales' are added for every non-integer divider(right side). 39 const divIntPrecInc = 9 40 41 const ERDivisionByZero = 1365 42 43 var _ ArithmeticOp = (*Div)(nil) 44 var _ sql.CollationCoercible = (*Div)(nil) 45 46 // Div expression represents "/" arithmetic operation 47 type Div struct { 48 BinaryExpressionStub 49 ops int32 50 divOps int32 51 } 52 53 // NewDiv creates a new Div / sql.Expression. 54 func NewDiv(left, right sql.Expression) *Div { 55 d := &Div{BinaryExpressionStub: BinaryExpressionStub{LeftChild: left, RightChild: right}} 56 setDivOps(d, countDivOps(d)) 57 setArithmeticOps(d, countArithmeticOps(d)) 58 return d 59 } 60 61 func (d *Div) Operator() string { 62 return sqlparser.DivStr 63 } 64 65 func (d *Div) SetOpCount(i int32) { 66 d.ops = i 67 } 68 69 func (d *Div) String() string { 70 return fmt.Sprintf("(%s / %s)", d.LeftChild, d.RightChild) 71 } 72 73 func (d *Div) DebugString() string { 74 return fmt.Sprintf("(%s / %s)", sql.DebugString(d.LeftChild), sql.DebugString(d.RightChild)) 75 } 76 77 // IsNullable implements the sql.Expression interface. 78 func (d *Div) IsNullable() bool { 79 return d.BinaryExpressionStub.IsNullable() 80 } 81 82 // Type returns the result type for this division expression. For nested division expressions, we prefer sending 83 // the result back as a float when possible, since division with floats is more efficient than division with Decimals. 84 // However, if this is the outermost division expression in an expression tree, we must return the result as a 85 // Decimal type in order to match MySQL's results exactly. 86 func (d *Div) Type() sql.Type { 87 return d.determineResultType(isOutermostDiv(d, 0, d.divOps)) 88 } 89 90 // internalType returns the internal result type for this division expression. For performance reasons, we prefer 91 // to use floats internally in division operations wherever possible, since division operations on floats can be 92 // orders of magnitude faster than division operations on Decimal types. 93 func (d *Div) internalType() sql.Type { 94 return d.determineResultType(false) 95 } 96 97 // CollationCoercibility implements the interface sql.CollationCoercible. 98 func (*Div) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 99 return sql.Collation_binary, 5 100 } 101 102 // WithChildren implements the Expression interface. 103 func (d *Div) WithChildren(children ...sql.Expression) (sql.Expression, error) { 104 if len(children) != 2 { 105 return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 2) 106 } 107 return NewDiv(children[0], children[1]), nil 108 } 109 110 // Eval implements the Expression interface. 111 func (d *Div) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 112 lval, rval, err := d.evalLeftRight(ctx, row) 113 if err != nil { 114 return nil, err 115 } 116 117 if lval == nil || rval == nil { 118 return nil, nil 119 } 120 121 lval, rval = d.convertLeftRight(ctx, lval, rval) 122 123 result, err := d.div(ctx, lval, rval) 124 if err != nil { 125 return nil, err 126 } 127 128 // Decimals must be rounded 129 if res, ok := result.(decimal.Decimal); ok { 130 if isOutermostArithmeticOp(d, d.ops) { 131 finalScale, _ := getFinalScale(ctx, row, d, 0) 132 return res.Round(finalScale), nil 133 } 134 } 135 136 return result, nil 137 } 138 139 func (d *Div) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) { 140 var lval, rval interface{} 141 var err error 142 143 // division used with Interval error is caught at parsing the query 144 lval, err = d.LeftChild.Eval(ctx, row) 145 if err != nil { 146 return nil, nil, err 147 } 148 149 // this operation is only done on the left value as the scale/fraction part of the leftmost value 150 // is used to calculate the scale of the final result. If the value is GetField of decimal type column 151 // the decimal value evaluated does not always match the scale of column type definition 152 if dt, ok := d.LeftChild.Type().(sql.DecimalType); ok { 153 if dVal, ok := lval.(decimal.Decimal); ok { 154 ts := int32(dt.Scale()) 155 if ts > dVal.Exponent()*-1 { 156 lval, err = decimal.NewFromString(dVal.StringFixed(ts)) 157 if err != nil { 158 return nil, nil, err 159 } 160 } 161 } 162 } 163 164 rval, err = d.RightChild.Eval(ctx, row) 165 if err != nil { 166 return nil, nil, err 167 } 168 169 return lval, rval, nil 170 } 171 172 // convertLeftRight returns the most appropriate type for left and right evaluated values, 173 // which may or may not be converted from its original type. 174 // It checks for float type column reference, then the both values converted to the same float type. 175 // Integer column references are treated as floats internally for performance reason, but the final result 176 // from the expression tree is converted to a Decimal in order to match MySQL's behavior. 177 // The decimal types of left and right value does NOT need to be the same. Both the types 178 // should be preserved. 179 func (d *Div) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}) { 180 typ := d.internalType() 181 lIsTimeType := types.IsTime(d.LeftChild.Type()) 182 rIsTimeType := types.IsTime(d.RightChild.Type()) 183 184 if types.IsFloat(typ) { 185 left = convertValueToType(ctx, typ, left, lIsTimeType) 186 } else { 187 left = convertToDecimalValue(left, lIsTimeType) 188 } 189 190 if types.IsFloat(typ) { 191 right = convertValueToType(ctx, typ, right, rIsTimeType) 192 } else { 193 right = convertToDecimalValue(right, rIsTimeType) 194 } 195 196 return left, right 197 } 198 199 func (d *Div) div(ctx *sql.Context, lval, rval interface{}) (interface{}, error) { 200 switch l := lval.(type) { 201 case float32: 202 switch r := rval.(type) { 203 case float32: 204 if r == 0 { 205 arithmeticWarning(ctx, ERDivisionByZero, "Division by 0") 206 return nil, nil 207 } 208 return l / r, nil 209 } 210 case float64: 211 switch r := rval.(type) { 212 case float64: 213 if r == 0 { 214 arithmeticWarning(ctx, ERDivisionByZero, "Division by 0") 215 return nil, nil 216 } 217 return l / r, nil 218 } 219 case decimal.Decimal: 220 switch r := rval.(type) { 221 case decimal.Decimal: 222 if r.Equal(decimal.NewFromInt(0)) { 223 arithmeticWarning(ctx, ERDivisionByZero, "Division by 0") 224 return nil, nil 225 } 226 227 lScale, rScale := -1*l.Exponent(), -1*r.Exponent() 228 inc := int32(math.Ceil(float64(lScale+rScale+divPrecInc) / divIntPrecInc)) 229 if lScale != 0 && rScale != 0 { 230 lInc := int32(math.Ceil(float64(lScale) / divIntPrecInc)) 231 rInc := int32(math.Ceil(float64(rScale) / divIntPrecInc)) 232 inc2 := lInc + rInc 233 if inc2 > inc { 234 inc = inc2 235 } 236 } 237 scale := inc * divIntPrecInc 238 l = l.Truncate(scale) 239 r = r.Truncate(scale) 240 241 // give it buffer of 2 additional scale to avoid the result to be rounded 242 res := l.DivRound(r, scale+2) 243 res = res.Truncate(scale) 244 return res, nil 245 } 246 } 247 248 return nil, errUnableToCast.New(lval, rval) 249 } 250 251 // determineResultType looks at the expressions in the expression tree with this division operation and determines 252 // the result type of this division expression. This involves looking at the types of the expressions in the tree, 253 // and looking for float types or Decimal types. If |outermostResult| is false, then we prefer to treat ints as floats 254 // (instead of Decimals) for performance reasons, but when |outermostResult| is true, we must treat ints as Decimals 255 // in order to match MySQL's behavior. 256 func (d *Div) determineResultType(outermostResult bool) sql.Type { 257 //TODO: what if both BindVars? should be constant folded 258 rTyp := d.RightChild.Type() 259 if types.IsDeferredType(rTyp) { 260 return rTyp 261 } 262 lTyp := d.LeftChild.Type() 263 if types.IsDeferredType(lTyp) { 264 return lTyp 265 } 266 267 if types.IsText(lTyp) || types.IsText(rTyp) { 268 return types.Float64 269 } 270 271 if types.IsJSON(lTyp) || types.IsJSON(rTyp) { 272 return types.Float64 273 } 274 275 if types.IsFloat(lTyp) || types.IsFloat(rTyp) { 276 return types.Float64 277 } 278 279 // Decimal only results from here on 280 281 if types.IsDatetimeType(lTyp) { 282 if dtType, ok := lTyp.(sql.DatetimeType); ok { 283 scale := uint8(dtType.Precision() + divPrecInc) 284 if scale > types.DecimalTypeMaxScale { 285 scale = types.DecimalTypeMaxScale 286 } 287 // TODO: determine actual precision 288 return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, scale) 289 } 290 } 291 292 if types.IsDecimal(lTyp) { 293 prec, scale := lTyp.(sql.DecimalType).Precision(), lTyp.(sql.DecimalType).Scale() 294 scale = scale + divPrecInc 295 if d.ops == -1 { 296 scale = (scale/divIntPrecInc + 1) * divIntPrecInc 297 prec = prec + scale 298 } else { 299 prec = prec + divPrecInc 300 } 301 302 if prec > types.DecimalTypeMaxPrecision { 303 prec = types.DecimalTypeMaxPrecision 304 } 305 if scale > types.DecimalTypeMaxScale { 306 scale = types.DecimalTypeMaxScale 307 } 308 return types.MustCreateDecimalType(prec, scale) 309 } 310 311 // All other types are treated as if they were integers 312 if d.ops == -1 { 313 return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, divIntPrecInc) 314 } 315 return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, divPrecInc) 316 } 317 318 // getFloatOrMaxDecimalType returns either Float64 or Decimal type with max precision and scale 319 // depending on column reference, expression types and evaluated value types. Otherwise, the return 320 // type is always max decimal type. |treatIntsAsFloats| is used for division operation optimization. 321 func getFloatOrMaxDecimalType(e sql.Expression, treatIntsAsFloats bool) sql.Type { 322 var resType sql.Type 323 var maxWhole, maxFrac uint8 324 sql.Inspect(e, func(expr sql.Expression) bool { 325 switch c := expr.(type) { 326 case *GetField: 327 ct := c.Type() 328 if treatIntsAsFloats && types.IsInteger(ct) { 329 resType = types.Float64 330 return false 331 } 332 // If there is float type column reference, the result type is always float. 333 if types.IsFloat(ct) { 334 resType = types.Float64 335 return false 336 } 337 if types.IsDecimal(ct) { 338 dt := ct.(sql.DecimalType) 339 p, s := dt.Precision(), dt.Scale() 340 if whole := p - s; whole > maxWhole { 341 maxWhole = whole 342 } 343 if s > maxFrac { 344 maxFrac = s 345 } 346 } 347 case *Convert: 348 if c.cachedDecimalType != nil { 349 p, s := GetPrecisionAndScale(c.cachedDecimalType) 350 if whole := p - s; whole > maxWhole { 351 maxWhole = whole 352 } 353 if s > maxFrac { 354 maxFrac = s 355 } 356 } 357 case *Literal: 358 if types.IsNumber(c.Type()) { 359 l, err := c.Eval(nil, nil) 360 if err == nil { 361 p, s := GetPrecisionAndScale(l) 362 if whole := p - s; whole > maxWhole { 363 maxWhole = whole 364 } 365 if s > maxFrac { 366 maxFrac = s 367 } 368 } 369 } 370 case sql.FunctionExpression: 371 // Mod.Type() calls this, so ignore it for infinite loop 372 if c.FunctionName() != "mod" { 373 resType = c.Type() 374 } 375 } 376 return true 377 }) 378 if resType == types.Float64 { 379 return resType 380 } 381 382 // defType is defined by evaluating all number literals available and defined column type. 383 defType, derr := types.CreateDecimalType(maxWhole+maxFrac, maxFrac) 384 if derr != nil { 385 return types.MustCreateDecimalType(65, 10) 386 } 387 388 return defType 389 } 390 391 // convertToDecimalValue returns value converted to decimaltype. 392 // If the value is invalid, it returns decimal 0. This function 393 // is used for 'div' or 'mod' arithmetic operation, which requires 394 // the result value to have precise precision and scale. 395 func convertToDecimalValue(val interface{}, isTimeType bool) interface{} { 396 if isTimeType { 397 val = convertTimeTypeToString(val) 398 } 399 switch v := val.(type) { 400 case bool: 401 val = 0 402 if v { 403 val = 1 404 } 405 default: 406 } 407 408 if _, ok := val.(decimal.Decimal); !ok { 409 p, s := GetPrecisionAndScale(val) 410 if p > types.DecimalTypeMaxPrecision { 411 p = types.DecimalTypeMaxPrecision 412 } 413 if s > types.DecimalTypeMaxScale { 414 s = types.DecimalTypeMaxScale 415 } 416 dtyp, err := types.CreateDecimalType(p, s) 417 if err != nil { 418 val = decimal.Zero 419 } 420 val, _, err = dtyp.Convert(val) 421 if err != nil { 422 val = decimal.Zero 423 } 424 } 425 426 return val 427 } 428 429 // countDivs returns the number of division operators in order on the left child node of the current node. 430 // This lets us count how many division operator used one after the other. E.g. 24/3/2/1 will have this structure: 431 // 432 // 'div' 433 // / \ 434 // 'div' 1 435 // / \ 436 // 'div' 2 437 // / \ 438 // 24 3 439 func countDivOps(e sql.Expression) int32 { 440 if e == nil { 441 return 0 442 } 443 if a, ok := e.(*Div); ok { 444 return countDivOps(a.LeftChild) + 1 445 } 446 if a, ok := e.(ArithmeticOp); ok { 447 return countDivOps(a.Left()) 448 } 449 return 0 450 } 451 452 // setDivs will set each node's DivScale to the number counted by countDivs. This allows us to 453 // keep track of whether the current Div expression is the last Div operation, so the result is 454 // rounded appropriately. 455 func setDivOps(e sql.Expression, divOps int32) { 456 if e == nil { 457 return 458 } 459 if a, isArithmeticOp := e.(ArithmeticOp); isArithmeticOp { 460 if d, ok := a.(*Div); ok { 461 d.divOps = divOps 462 } 463 setDivOps(a.Left(), divOps) 464 setDivOps(a.Right(), divOps) 465 } 466 if tup, ok := e.(Tuple); ok { 467 for _, expr := range tup { 468 setDivOps(expr, divOps) 469 } 470 } 471 return 472 } 473 474 // isOutermostDiv returns whether the expression we're currently evaluating is 475 // the last division operation of all continuous divisions. 476 // E.g. the top 'div' (divided by 1) is the outermost/last division that is calculated: 477 // 478 // 'div' 479 // / \ 480 // 'div' 1 481 // / \ 482 // 'div' 2 483 // / \ 484 // 24 3 485 func isOutermostDiv(e sql.Expression, d, dScale int32) bool { 486 if e == nil { 487 return false 488 } 489 490 if a, ok := e.(*Div); ok { 491 d = d + 1 492 if d == dScale { 493 return true 494 } 495 return isOutermostDiv(a.LeftChild, d, dScale) 496 } 497 498 if a, ok := e.(ArithmeticOp); ok { 499 return isOutermostDiv(a.Left(), d, dScale) 500 } 501 502 return false 503 } 504 505 // getFinalScale returns the final scale of the result value. 506 // it traverses both the left and right nodes looking for Div, Arithmetic, and Literal nodes 507 func getFinalScale(ctx *sql.Context, row sql.Row, expr sql.Expression, divOpCnt int32) (int32, bool) { 508 if expr == nil { 509 return 0, false 510 } 511 512 if div, isDiv := expr.(*Div); isDiv { 513 // TODO: there's gotta be a better way of determining if this is the leftmost div... 514 fScale := int32(divPrecInc) 515 divOpCnt = divOpCnt + 1 516 if divOpCnt == div.divOps { 517 // TODO: redundant call to Eval for LeftChild 518 lval, err := div.LeftChild.Eval(ctx, row) 519 if err != nil { 520 return 0, false 521 } 522 _, s := GetPrecisionAndScale(lval) 523 typ := div.LeftChild.Type() 524 if dt, dok := typ.(sql.DecimalType); dok { 525 ts := dt.Scale() 526 if ts > s { 527 s = ts 528 } 529 } 530 fScale += int32(s) 531 } else { 532 // We only care about left scale for divs 533 lScale, _ := getFinalScale(ctx, row, div.LeftChild, divOpCnt) 534 fScale += lScale 535 } 536 537 if fScale > types.DecimalTypeMaxScale { 538 fScale = types.DecimalTypeMaxScale 539 } 540 return fScale, true 541 } 542 543 if a, isArith := expr.(*Arithmetic); isArith { 544 lScale, lHasDiv := getFinalScale(ctx, row, a.Left(), divOpCnt) 545 rScale, rHasDiv := getFinalScale(ctx, row, a.Right(), divOpCnt) 546 var fScale int32 547 switch a.Operator() { 548 case sqlparser.PlusStr, sqlparser.MinusStr: 549 if lScale > rScale { 550 fScale = lScale 551 } else { 552 fScale = rScale 553 } 554 case sqlparser.MultStr: 555 fScale = lScale + rScale 556 } 557 if fScale > types.DecimalTypeMaxScale { 558 fScale = types.DecimalTypeMaxScale 559 } 560 return fScale, lHasDiv || rHasDiv 561 } 562 563 // TODO: this is just a guess of what mod should do with scale; test this 564 if m, isMod := expr.(*Mod); isMod { 565 fScale, leftHasDiv := getFinalScale(ctx, row, m.LeftChild, divOpCnt) 566 rScale, rightHasDiv := getFinalScale(ctx, row, m.RightChild, divOpCnt) 567 if rScale > fScale { 568 fScale = rScale 569 } 570 if fScale > types.DecimalTypeMaxScale { 571 fScale = types.DecimalTypeMaxScale 572 } 573 return fScale, leftHasDiv || rightHasDiv 574 } 575 576 // TODO: likely need a case for IntDiv 577 578 var fScale uint8 579 if lit, isLit := expr.(*Literal); isLit { 580 _, fScale = GetPrecisionAndScale(lit.value) 581 } 582 typ := expr.Type() 583 if dt, dok := typ.(sql.DecimalType); dok { 584 ts := dt.Scale() 585 if ts > fScale { 586 fScale = ts 587 } 588 } 589 590 return int32(fScale), false 591 } 592 593 // GetDecimalPrecisionAndScale returns precision and scale for given string formatted float/double number. 594 func GetDecimalPrecisionAndScale(val string) (uint8, uint8) { 595 scale := 0 596 precScale := strings.Split(strings.TrimPrefix(val, "-"), ".") 597 if len(precScale) != 1 { 598 scale = len(precScale[1]) 599 } 600 precision := len((precScale)[0]) + scale 601 return uint8(precision), uint8(scale) 602 } 603 604 // GetPrecisionAndScale converts the value to string format and parses it to get the precision and scale. 605 func GetPrecisionAndScale(val interface{}) (uint8, uint8) { 606 var str string 607 switch v := val.(type) { 608 case time.Time: 609 str = fmt.Sprintf("%v", v.In(time.UTC).Format("2006-01-02 15:04:05")) 610 case decimal.Decimal: 611 str = v.StringFixed(v.Exponent() * -1) 612 case float32: 613 d := decimal.NewFromFloat32(v) 614 str = d.StringFixed(d.Exponent() * -1) 615 case float64: 616 d := decimal.NewFromFloat(v) 617 str = d.StringFixed(d.Exponent() * -1) 618 default: 619 str = fmt.Sprintf("%v", v) 620 } 621 return GetDecimalPrecisionAndScale(str) 622 } 623 624 var _ ArithmeticOp = (*IntDiv)(nil) 625 var _ sql.CollationCoercible = (*IntDiv)(nil) 626 627 // IntDiv expression represents integer "div" arithmetic operation 628 type IntDiv struct { 629 BinaryExpressionStub 630 ops int32 631 } 632 633 // NewIntDiv creates a new IntDiv 'div' sql.Expression. 634 func NewIntDiv(left, right sql.Expression) *IntDiv { 635 a := &IntDiv{BinaryExpressionStub{LeftChild: left, RightChild: right}, 0} 636 ops := countArithmeticOps(a) 637 setArithmeticOps(a, ops) 638 return a 639 } 640 641 func (i *IntDiv) Operator() string { 642 return sqlparser.IntDivStr 643 } 644 645 func (i *IntDiv) SetOpCount(i2 int32) { 646 i.ops = i2 647 } 648 649 func (i *IntDiv) String() string { 650 return fmt.Sprintf("(%s div %s)", i.LeftChild, i.RightChild) 651 } 652 653 func (i *IntDiv) DebugString() string { 654 return fmt.Sprintf("(%s div %s)", sql.DebugString(i.LeftChild), sql.DebugString(i.RightChild)) 655 } 656 657 // IsNullable implements the sql.Expression interface. 658 func (i *IntDiv) IsNullable() bool { 659 return i.BinaryExpressionStub.IsNullable() 660 } 661 662 // Type returns the greatest type for given operation. 663 func (i *IntDiv) Type() sql.Type { 664 lTyp := i.LeftChild.Type() 665 rTyp := i.RightChild.Type() 666 667 if types.IsUnsigned(lTyp) || types.IsUnsigned(rTyp) { 668 return types.Uint64 669 } 670 671 return types.Int64 672 } 673 674 // CollationCoercibility implements the interface sql.CollationCoercible. 675 func (*IntDiv) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 676 return sql.Collation_binary, 5 677 } 678 679 // WithChildren implements the Expression interface. 680 func (i *IntDiv) WithChildren(children ...sql.Expression) (sql.Expression, error) { 681 if len(children) != 2 { 682 return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 2) 683 } 684 return NewIntDiv(children[0], children[1]), nil 685 } 686 687 // Eval implements the Expression interface. 688 func (i *IntDiv) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 689 lval, rval, err := i.evalLeftRight(ctx, row) 690 if err != nil { 691 return nil, err 692 } 693 694 if lval == nil || rval == nil { 695 return nil, nil 696 } 697 698 lval, rval = i.convertLeftRight(ctx, lval, rval) 699 700 return intDiv(ctx, lval, rval) 701 } 702 703 func (i *IntDiv) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) { 704 var lval, rval interface{} 705 var err error 706 707 // int division used with Interval error is caught at parsing the query 708 lval, err = i.LeftChild.Eval(ctx, row) 709 if err != nil { 710 return nil, nil, err 711 } 712 713 rval, err = i.RightChild.Eval(ctx, row) 714 if err != nil { 715 return nil, nil, err 716 } 717 718 return lval, rval, nil 719 } 720 721 // convertLeftRight return most appropriate value for left and right from evaluated value, 722 // which can might or might not be converted from its original value. 723 // It checks for float type column reference, then the both values converted to the same float types. 724 // If there is no float type column reference, both values should be handled as decimal type 725 // The decimal types of left and right value does NOT need to be the same. Both the types 726 // should be preserved. 727 func (i *IntDiv) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}) { 728 var typ sql.Type 729 lTyp, rTyp := i.LeftChild.Type(), i.RightChild.Type() 730 lIsTimeType := types.IsTime(lTyp) 731 rIsTimeType := types.IsTime(rTyp) 732 733 if types.IsText(lTyp) || types.IsText(rTyp) { 734 typ = types.Float64 735 } else if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) { 736 typ = types.Uint64 737 } else if (lIsTimeType && rIsTimeType) || (types.IsSigned(lTyp) && types.IsSigned(rTyp)) { 738 typ = types.Int64 739 } else { 740 typ = types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, 0) 741 } 742 743 if types.IsInteger(typ) || types.IsFloat(typ) { 744 left = convertValueToType(ctx, typ, left, lIsTimeType) 745 right = convertValueToType(ctx, typ, right, rIsTimeType) 746 } else { 747 left = convertToDecimalValue(left, lIsTimeType) 748 right = convertToDecimalValue(right, rIsTimeType) 749 } 750 751 return left, right 752 } 753 754 func intDiv(ctx *sql.Context, lval, rval interface{}) (interface{}, error) { 755 switch l := lval.(type) { 756 case uint64: 757 switch r := rval.(type) { 758 case uint64: 759 if r == 0 { 760 arithmeticWarning(ctx, ERDivisionByZero, "Division by 0") 761 return nil, nil 762 } 763 return l / r, nil 764 } 765 case int64: 766 switch r := rval.(type) { 767 case int64: 768 if r == 0 { 769 arithmeticWarning(ctx, ERDivisionByZero, "Division by 0") 770 return nil, nil 771 } 772 return l / r, nil 773 } 774 case float64: 775 switch r := rval.(type) { 776 case float64: 777 if r == 0 { 778 arithmeticWarning(ctx, ERDivisionByZero, "Division by 0") 779 return nil, nil 780 } 781 res := l / r 782 return int64(math.Floor(res)), nil 783 } 784 case decimal.Decimal: 785 switch r := rval.(type) { 786 case decimal.Decimal: 787 if r.Equal(decimal.NewFromInt(0)) { 788 arithmeticWarning(ctx, ERDivisionByZero, "Division by 0") 789 return nil, nil 790 } 791 792 // intDiv operation gets the integer part of the divided value without rounding the result with 0 precision 793 // We get division result with non-zero precision and then truncate it to get integer part without it being rounded 794 divRes := l.DivRound(r, 2).Truncate(0) 795 796 // cannot use IntPart() function of decimal.Decimal package as it returns 0 as undefined value for out of range value 797 // it causes valid result value of 0 to be the same as invalid out of range value of 0. The fraction part 798 // should not be rounded, so truncate the result wih 0 precision. 799 intPart, err := strconv.ParseInt(divRes.String(), 10, 64) 800 if err != nil { 801 return nil, ErrIntDivDataOutOfRange.New(l.StringFixed(l.Exponent()), r.StringFixed(r.Exponent())) 802 } 803 804 return intPart, nil 805 } 806 } 807 808 return nil, errUnableToCast.New(lval, rval) 809 }