github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/evaluator/evaluator_binop.go (about) 1 // Copyright 2015 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package evaluator 15 16 import ( 17 "math" 18 19 "github.com/insionng/yougam/libraries/juju/errors" 20 "github.com/insionng/yougam/libraries/pingcap/tidb/ast" 21 "github.com/insionng/yougam/libraries/pingcap/tidb/parser/opcode" 22 "github.com/insionng/yougam/libraries/pingcap/tidb/util/types" 23 ) 24 25 const ( 26 zeroI64 int64 = 0 27 oneI64 int64 = 1 28 ) 29 30 func (e *Evaluator) binaryOperation(o *ast.BinaryOperationExpr) bool { 31 switch o.Op { 32 case opcode.AndAnd, opcode.OrOr, opcode.LogicXor: 33 return e.handleLogicOperation(o) 34 case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: 35 return e.handleComparisonOp(o) 36 case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor: 37 return e.handleBitOp(o) 38 case opcode.Plus, opcode.Minus, opcode.Mod, opcode.Div, opcode.Mul, opcode.IntDiv: 39 return e.handleArithmeticOp(o) 40 default: 41 e.err = ErrInvalidOperation 42 return false 43 } 44 } 45 46 func (e *Evaluator) handleLogicOperation(o *ast.BinaryOperationExpr) bool { 47 switch o.Op { 48 case opcode.AndAnd: 49 return e.handleAndAnd(o) 50 case opcode.OrOr: 51 return e.handleOrOr(o) 52 case opcode.LogicXor: 53 return e.handleXor(o) 54 default: 55 e.err = ErrInvalidOperation.Gen("unkown operator %s", o.Op) 56 return false 57 } 58 } 59 60 func (e *Evaluator) handleAndAnd(o *ast.BinaryOperationExpr) bool { 61 leftDatum := o.L.GetDatum() 62 rightDatum := o.R.GetDatum() 63 if leftDatum.Kind() != types.KindNull { 64 x, err := leftDatum.ToBool() 65 if err != nil { 66 e.err = errors.Trace(err) 67 return false 68 } else if x == 0 { 69 // false && any other types is false 70 o.SetInt64(x) 71 return true 72 } 73 } 74 if rightDatum.Kind() != types.KindNull { 75 y, err := rightDatum.ToBool() 76 if err != nil { 77 e.err = errors.Trace(err) 78 return false 79 } else if y == 0 { 80 o.SetInt64(y) 81 return true 82 } 83 } 84 if leftDatum.Kind() == types.KindNull || rightDatum.Kind() == types.KindNull { 85 o.SetNull() 86 return true 87 } 88 o.SetInt64(int64(1)) 89 return true 90 } 91 92 func (e *Evaluator) handleOrOr(o *ast.BinaryOperationExpr) bool { 93 leftDatum := o.L.GetDatum() 94 if leftDatum.Kind() != types.KindNull { 95 x, err := leftDatum.ToBool() 96 if err != nil { 97 e.err = errors.Trace(err) 98 return false 99 } else if x == 1 { 100 // true || any other types is true. 101 o.SetInt64(x) 102 return true 103 } 104 } 105 righDatum := o.R.GetDatum() 106 if righDatum.Kind() != types.KindNull { 107 y, err := righDatum.ToBool() 108 if err != nil { 109 e.err = errors.Trace(err) 110 return false 111 } else if y == 1 { 112 o.SetInt64(y) 113 return true 114 } 115 } 116 if leftDatum.Kind() == types.KindNull || righDatum.Kind() == types.KindNull { 117 o.SetNull() 118 return true 119 } 120 o.SetInt64(int64(0)) 121 return true 122 } 123 124 func (e *Evaluator) handleXor(o *ast.BinaryOperationExpr) bool { 125 leftDatum := o.L.GetDatum() 126 righDatum := o.R.GetDatum() 127 if leftDatum.Kind() == types.KindNull || righDatum.Kind() == types.KindNull { 128 o.SetNull() 129 return true 130 } 131 x, err := leftDatum.ToBool() 132 if err != nil { 133 e.err = errors.Trace(err) 134 return false 135 } 136 137 y, err := righDatum.ToBool() 138 if err != nil { 139 e.err = errors.Trace(err) 140 return false 141 } 142 if x == y { 143 o.SetInt64(int64(0)) 144 } else { 145 o.SetInt64(int64(1)) 146 } 147 return true 148 } 149 150 func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool { 151 a, b := types.CoerceDatum(*o.L.GetDatum(), *o.R.GetDatum()) 152 if a.Kind() == types.KindNull || b.Kind() == types.KindNull { 153 // for <=>, if a and b are both nil, return true. 154 // if a or b is nil, return false. 155 if o.Op == opcode.NullEQ { 156 if a.Kind() == types.KindNull && b.Kind() == types.KindNull { 157 o.SetInt64(oneI64) 158 } else { 159 o.SetInt64(zeroI64) 160 } 161 } else { 162 o.SetNull() 163 } 164 return true 165 } 166 167 n, err := a.CompareDatum(b) 168 169 if err != nil { 170 e.err = errors.Trace(err) 171 return false 172 } 173 174 r, err := getCompResult(o.Op, n) 175 if err != nil { 176 e.err = errors.Trace(err) 177 return false 178 } 179 if r { 180 o.SetInt64(oneI64) 181 } else { 182 o.SetInt64(zeroI64) 183 } 184 return true 185 } 186 187 func getCompResult(op opcode.Op, value int) (bool, error) { 188 switch op { 189 case opcode.LT: 190 return value < 0, nil 191 case opcode.LE: 192 return value <= 0, nil 193 case opcode.GE: 194 return value >= 0, nil 195 case opcode.GT: 196 return value > 0, nil 197 case opcode.EQ: 198 return value == 0, nil 199 case opcode.NE: 200 return value != 0, nil 201 case opcode.NullEQ: 202 return value == 0, nil 203 default: 204 return false, ErrInvalidOperation.Gen("invalid op %v in comparision operation", op) 205 } 206 } 207 208 func (e *Evaluator) handleBitOp(o *ast.BinaryOperationExpr) bool { 209 a, b := types.CoerceDatum(*o.L.GetDatum(), *o.R.GetDatum()) 210 211 if a.Kind() == types.KindNull || b.Kind() == types.KindNull { 212 o.SetNull() 213 return true 214 } 215 216 x, err := a.ToInt64() 217 if err != nil { 218 e.err = errors.Trace(err) 219 return false 220 } 221 222 y, err := b.ToInt64() 223 if err != nil { 224 e.err = errors.Trace(err) 225 return false 226 } 227 228 // use a int64 for bit operator, return uint64 229 switch o.Op { 230 case opcode.And: 231 o.SetUint64(uint64(x & y)) 232 case opcode.Or: 233 o.SetUint64(uint64(x | y)) 234 case opcode.Xor: 235 o.SetUint64(uint64(x ^ y)) 236 case opcode.RightShift: 237 o.SetUint64(uint64(x) >> uint64(y)) 238 case opcode.LeftShift: 239 o.SetUint64(uint64(x) << uint64(y)) 240 default: 241 e.err = ErrInvalidOperation.Gen("invalid op %v in bit operation", o.Op) 242 return false 243 } 244 return true 245 } 246 247 func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool { 248 a, err := coerceArithmetic(*o.L.GetDatum()) 249 if err != nil { 250 e.err = errors.Trace(err) 251 return false 252 } 253 254 b, err := coerceArithmetic(*o.R.GetDatum()) 255 if err != nil { 256 e.err = errors.Trace(err) 257 return false 258 } 259 260 a, b = types.CoerceDatum(a, b) 261 if a.Kind() == types.KindNull || b.Kind() == types.KindNull { 262 o.SetNull() 263 return true 264 } 265 266 var result types.Datum 267 switch o.Op { 268 case opcode.Plus: 269 result, e.err = computePlus(a, b) 270 case opcode.Minus: 271 result, e.err = computeMinus(a, b) 272 case opcode.Mul: 273 result, e.err = computeMul(a, b) 274 case opcode.Div: 275 result, e.err = computeDiv(a, b) 276 case opcode.Mod: 277 result, e.err = computeMod(a, b) 278 case opcode.IntDiv: 279 result, e.err = computeIntDiv(a, b) 280 default: 281 e.err = ErrInvalidOperation.Gen("invalid op %v in arithmetic operation", o.Op) 282 return false 283 } 284 o.SetDatum(result) 285 return e.err == nil 286 } 287 288 func computePlus(a, b types.Datum) (d types.Datum, err error) { 289 switch a.Kind() { 290 case types.KindInt64: 291 switch b.Kind() { 292 case types.KindInt64: 293 r, err1 := types.AddInt64(a.GetInt64(), b.GetInt64()) 294 d.SetInt64(r) 295 return d, errors.Trace(err1) 296 case types.KindUint64: 297 r, err1 := types.AddInteger(b.GetUint64(), a.GetInt64()) 298 d.SetUint64(r) 299 return d, errors.Trace(err1) 300 } 301 case types.KindUint64: 302 switch b.Kind() { 303 case types.KindInt64: 304 r, err1 := types.AddInteger(a.GetUint64(), b.GetInt64()) 305 d.SetUint64(r) 306 return d, errors.Trace(err1) 307 case types.KindUint64: 308 r, err1 := types.AddUint64(a.GetUint64(), b.GetUint64()) 309 d.SetUint64(r) 310 return d, errors.Trace(err1) 311 } 312 case types.KindFloat64: 313 switch b.Kind() { 314 case types.KindFloat64: 315 r := a.GetFloat64() + b.GetFloat64() 316 d.SetFloat64(r) 317 return d, nil 318 } 319 case types.KindMysqlDecimal: 320 switch b.Kind() { 321 case types.KindMysqlDecimal: 322 r := a.GetMysqlDecimal().Add(b.GetMysqlDecimal()) 323 d.SetMysqlDecimal(r) 324 return d, nil 325 } 326 } 327 _, err = types.InvOp2(a.GetValue(), b.GetValue(), opcode.Plus) 328 return d, err 329 } 330 331 func computeMinus(a, b types.Datum) (d types.Datum, err error) { 332 switch a.Kind() { 333 case types.KindInt64: 334 switch b.Kind() { 335 case types.KindInt64: 336 r, err1 := types.SubInt64(a.GetInt64(), b.GetInt64()) 337 d.SetInt64(r) 338 return d, errors.Trace(err1) 339 case types.KindUint64: 340 r, err1 := types.SubIntWithUint(a.GetInt64(), b.GetUint64()) 341 d.SetUint64(r) 342 return d, errors.Trace(err1) 343 } 344 case types.KindUint64: 345 switch b.Kind() { 346 case types.KindInt64: 347 r, err1 := types.SubUintWithInt(a.GetUint64(), b.GetInt64()) 348 d.SetUint64(r) 349 return d, errors.Trace(err1) 350 case types.KindUint64: 351 r, err1 := types.SubUint64(a.GetUint64(), b.GetUint64()) 352 d.SetUint64(r) 353 return d, errors.Trace(err1) 354 } 355 case types.KindFloat64: 356 switch b.Kind() { 357 case types.KindFloat64: 358 r := a.GetFloat64() - b.GetFloat64() 359 d.SetFloat64(r) 360 return d, nil 361 } 362 case types.KindMysqlDecimal: 363 switch b.Kind() { 364 case types.KindMysqlDecimal: 365 r := a.GetMysqlDecimal().Sub(b.GetMysqlDecimal()) 366 d.SetMysqlDecimal(r) 367 return d, nil 368 } 369 } 370 _, err = types.InvOp2(a.GetValue(), b.GetValue(), opcode.Minus) 371 return d, errors.Trace(err) 372 } 373 374 func computeMul(a, b types.Datum) (d types.Datum, err error) { 375 switch a.Kind() { 376 case types.KindInt64: 377 switch b.Kind() { 378 case types.KindInt64: 379 r, err1 := types.MulInt64(a.GetInt64(), b.GetInt64()) 380 d.SetInt64(r) 381 return d, errors.Trace(err1) 382 case types.KindUint64: 383 r, err1 := types.MulInteger(b.GetUint64(), a.GetInt64()) 384 d.SetUint64(r) 385 return d, errors.Trace(err1) 386 } 387 case types.KindUint64: 388 switch b.Kind() { 389 case types.KindInt64: 390 r, err1 := types.MulInteger(a.GetUint64(), b.GetInt64()) 391 d.SetUint64(r) 392 return d, errors.Trace(err1) 393 case types.KindUint64: 394 r, err1 := types.MulUint64(a.GetUint64(), b.GetUint64()) 395 d.SetUint64(r) 396 return d, errors.Trace(err1) 397 } 398 case types.KindFloat64: 399 switch b.Kind() { 400 case types.KindFloat64: 401 r := a.GetFloat64() * b.GetFloat64() 402 d.SetFloat64(r) 403 return d, nil 404 } 405 case types.KindMysqlDecimal: 406 switch b.Kind() { 407 case types.KindMysqlDecimal: 408 r := a.GetMysqlDecimal().Mul(b.GetMysqlDecimal()) 409 d.SetMysqlDecimal(r) 410 return d, nil 411 } 412 } 413 414 _, err = types.InvOp2(a.GetValue(), b.GetValue(), opcode.Mul) 415 return d, errors.Trace(err) 416 } 417 418 func computeDiv(a, b types.Datum) (d types.Datum, err error) { 419 // MySQL support integer division Div and division operator / 420 // we use opcode.Div for division operator and will use another for integer division later. 421 // for division operator, we will use float64 for calculation. 422 switch a.Kind() { 423 case types.KindFloat64: 424 y, err1 := b.ToFloat64() 425 if err1 != nil { 426 return d, errors.Trace(err1) 427 } 428 429 if y == 0 { 430 return d, nil 431 } 432 433 x := a.GetFloat64() 434 d.SetFloat64(x / y) 435 return d, nil 436 default: 437 // the scale of the result is the scale of the first operand plus 438 // the value of the div_precision_increment system variable (which is 4 by default) 439 // we will use 4 here 440 xa, err1 := a.ToDecimal() 441 if err != nil { 442 return d, errors.Trace(err1) 443 } 444 445 xb, err1 := b.ToDecimal() 446 if err1 != nil { 447 return d, errors.Trace(err1) 448 } 449 if f, _ := xb.Float64(); f == 0 { 450 // division by zero return null 451 return d, nil 452 } 453 454 d.SetMysqlDecimal(xa.Div(xb)) 455 return d, nil 456 } 457 } 458 459 func computeMod(a, b types.Datum) (d types.Datum, err error) { 460 switch a.Kind() { 461 case types.KindInt64: 462 x := a.GetInt64() 463 switch b.Kind() { 464 case types.KindInt64: 465 y := b.GetInt64() 466 if y == 0 { 467 return d, nil 468 } 469 d.SetInt64(x % y) 470 return d, nil 471 case types.KindUint64: 472 y := b.GetUint64() 473 if y == 0 { 474 return d, nil 475 } else if x < 0 { 476 d.SetInt64(-int64(uint64(-x) % y)) 477 // first is int64, return int64. 478 return d, nil 479 } 480 d.SetInt64(int64(uint64(x) % y)) 481 return d, nil 482 } 483 case types.KindUint64: 484 x := a.GetUint64() 485 switch b.Kind() { 486 case types.KindInt64: 487 y := b.GetInt64() 488 if y == 0 { 489 return d, nil 490 } else if y < 0 { 491 // first is uint64, return uint64. 492 d.SetUint64(uint64(x % uint64(-y))) 493 return d, nil 494 } 495 d.SetUint64(x % uint64(y)) 496 return d, nil 497 case types.KindUint64: 498 y := b.GetUint64() 499 if y == 0 { 500 return d, nil 501 } 502 d.SetUint64(x % y) 503 return d, nil 504 } 505 case types.KindFloat64: 506 x := a.GetFloat64() 507 switch b.Kind() { 508 case types.KindFloat64: 509 y := b.GetFloat64() 510 if y == 0 { 511 return d, nil 512 } 513 d.SetFloat64(math.Mod(x, y)) 514 return d, nil 515 } 516 case types.KindMysqlDecimal: 517 x := a.GetMysqlDecimal() 518 switch b.Kind() { 519 case types.KindMysqlDecimal: 520 y := b.GetMysqlDecimal() 521 xf, _ := x.Float64() 522 yf, _ := y.Float64() 523 if yf == 0 { 524 return d, nil 525 } 526 d.SetFloat64(math.Mod(xf, yf)) 527 return d, nil 528 } 529 } 530 _, err = types.InvOp2(a.GetValue(), b.GetValue(), opcode.Mod) 531 return d, errors.Trace(err) 532 } 533 534 func computeIntDiv(a, b types.Datum) (d types.Datum, err error) { 535 switch a.Kind() { 536 case types.KindInt64: 537 x := a.GetInt64() 538 switch b.Kind() { 539 case types.KindInt64: 540 y := b.GetInt64() 541 if y == 0 { 542 return d, nil 543 } 544 r, err1 := types.DivInt64(x, y) 545 d.SetInt64(r) 546 return d, errors.Trace(err1) 547 case types.KindUint64: 548 y := b.GetUint64() 549 if y == 0 { 550 return d, nil 551 } 552 r, err1 := types.DivIntWithUint(x, y) 553 d.SetUint64(r) 554 return d, errors.Trace(err1) 555 } 556 case types.KindUint64: 557 x := a.GetUint64() 558 switch b.Kind() { 559 case types.KindInt64: 560 y := b.GetInt64() 561 if y == 0 { 562 return d, nil 563 } 564 r, err1 := types.DivUintWithInt(x, y) 565 d.SetUint64(r) 566 return d, errors.Trace(err1) 567 case types.KindUint64: 568 y := b.GetUint64() 569 if y == 0 { 570 return d, nil 571 } 572 d.SetUint64(x / y) 573 return d, nil 574 } 575 } 576 577 // if any is none integer, use decimal to calculate 578 x, err := a.ToDecimal() 579 if err != nil { 580 return d, errors.Trace(err) 581 } 582 583 y, err := b.ToDecimal() 584 if err != nil { 585 return d, errors.Trace(err) 586 } 587 588 if f, _ := y.Float64(); f == 0 { 589 return d, nil 590 } 591 592 d.SetInt64(x.Div(y).IntPart()) 593 return d, nil 594 } 595 596 func coerceArithmetic(a types.Datum) (d types.Datum, err error) { 597 switch a.Kind() { 598 case types.KindString, types.KindBytes: 599 // MySQL will convert string to float for arithmetic operation 600 f, err := types.StrToFloat(a.GetString()) 601 if err != nil { 602 return d, errors.Trace(err) 603 } 604 d.SetFloat64(f) 605 return d, errors.Trace(err) 606 case types.KindMysqlTime: 607 // if time has no precision, return int64 608 t := a.GetMysqlTime() 609 de := t.ToNumber() 610 if t.Fsp == 0 { 611 d.SetInt64(de.IntPart()) 612 return d, nil 613 } 614 d.SetMysqlDecimal(de) 615 return d, nil 616 case types.KindMysqlDuration: 617 // if duration has no precision, return int64 618 du := a.GetMysqlDuration() 619 de := du.ToNumber() 620 if du.Fsp == 0 { 621 d.SetInt64(de.IntPart()) 622 return d, nil 623 } 624 d.SetMysqlDecimal(de) 625 return d, nil 626 case types.KindMysqlHex: 627 d.SetFloat64(a.GetMysqlHex().ToNumber()) 628 return d, nil 629 case types.KindMysqlBit: 630 d.SetFloat64(a.GetMysqlBit().ToNumber()) 631 return d, nil 632 case types.KindMysqlEnum: 633 d.SetFloat64(a.GetMysqlEnum().ToNumber()) 634 return d, nil 635 case types.KindMysqlSet: 636 d.SetFloat64(a.GetMysqlSet().ToNumber()) 637 return d, nil 638 default: 639 return a, nil 640 } 641 }