github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/evaluator/evaluator.go (about) 1 // Copyright 2015 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package evaluator 15 16 import ( 17 "strings" 18 19 "github.com/insionng/yougam/libraries/juju/errors" 20 "github.com/insionng/yougam/libraries/pingcap/tidb/ast" 21 "github.com/insionng/yougam/libraries/pingcap/tidb/context" 22 "github.com/insionng/yougam/libraries/pingcap/tidb/mysql" 23 "github.com/insionng/yougam/libraries/pingcap/tidb/parser/opcode" 24 "github.com/insionng/yougam/libraries/pingcap/tidb/sessionctx/variable" 25 "github.com/insionng/yougam/libraries/pingcap/tidb/terror" 26 "github.com/insionng/yougam/libraries/pingcap/tidb/util/types" 27 ) 28 29 // Error instances. 30 var ( 31 ErrInvalidOperation = terror.ClassEvaluator.New(CodeInvalidOperation, "invalid operation") 32 ) 33 34 // Error codes. 35 const ( 36 CodeInvalidOperation terror.ErrCode = 1 37 ) 38 39 // Eval evaluates an expression to a datum. 40 func Eval(ctx context.Context, expr ast.ExprNode) (d types.Datum, err error) { 41 if ast.IsEvaluated(expr) { 42 return *expr.GetDatum(), nil 43 } 44 e := &Evaluator{ctx: ctx} 45 expr.Accept(e) 46 if e.err != nil { 47 return d, errors.Trace(e.err) 48 } 49 if ast.IsPreEvaluable(expr) && (expr.GetFlag()&ast.FlagHasFunc == 0) { 50 expr.SetFlag(expr.GetFlag() | ast.FlagPreEvaluated) 51 } 52 return *expr.GetDatum(), nil 53 } 54 55 // EvalBool evalueates an expression to a boolean value. 56 func EvalBool(ctx context.Context, expr ast.ExprNode) (bool, error) { 57 val, err := Eval(ctx, expr) 58 if err != nil { 59 return false, errors.Trace(err) 60 } 61 if val.Kind() == types.KindNull { 62 return false, nil 63 } 64 65 i, err := val.ToBool() 66 if err != nil { 67 return false, errors.Trace(err) 68 } 69 return i != 0, nil 70 } 71 72 func boolToInt64(v bool) int64 { 73 if v { 74 return int64(1) 75 } 76 return int64(0) 77 } 78 79 // Evaluator is an ast Visitor that evaluates an expression. 80 type Evaluator struct { 81 ctx context.Context 82 err error 83 } 84 85 // Enter implements ast.Visitor interface. 86 func (e *Evaluator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { 87 return in, false 88 } 89 90 // Leave implements ast.Visitor interface. 91 func (e *Evaluator) Leave(in ast.Node) (out ast.Node, ok bool) { 92 switch v := in.(type) { 93 case *ast.AggregateFuncExpr: 94 ok = e.aggregateFunc(v) 95 case *ast.BetweenExpr: 96 ok = e.between(v) 97 case *ast.BinaryOperationExpr: 98 ok = e.binaryOperation(v) 99 case *ast.CaseExpr: 100 ok = e.caseExpr(v) 101 case *ast.ColumnName: 102 ok = true 103 case *ast.ColumnNameExpr: 104 ok = e.columnName(v) 105 case *ast.CompareSubqueryExpr: 106 ok = e.compareSubquery(v) 107 case *ast.DefaultExpr: 108 ok = e.defaultExpr(v) 109 case *ast.ExistsSubqueryExpr: 110 ok = e.existsSubquery(v) 111 case *ast.FuncCallExpr: 112 ok = e.funcCall(v) 113 case *ast.FuncCastExpr: 114 ok = e.funcCast(v) 115 case *ast.IsNullExpr: 116 ok = e.isNull(v) 117 case *ast.IsTruthExpr: 118 ok = e.isTruth(v) 119 case *ast.ParamMarkerExpr: 120 ok = e.paramMarker(v) 121 case *ast.ParenthesesExpr: 122 ok = e.parentheses(v) 123 case *ast.PatternInExpr: 124 ok = e.patternIn(v) 125 case *ast.PatternLikeExpr: 126 ok = e.patternLike(v) 127 case *ast.PatternRegexpExpr: 128 ok = e.patternRegexp(v) 129 case *ast.PositionExpr: 130 ok = e.position(v) 131 case *ast.RowExpr: 132 ok = e.row(v) 133 case *ast.SubqueryExpr: 134 ok = e.subqueryExpr(v) 135 case *ast.UnaryOperationExpr: 136 ok = e.unaryOperation(v) 137 case *ast.ValueExpr: 138 ok = true 139 case *ast.ValuesExpr: 140 ok = e.values(v) 141 case *ast.VariableExpr: 142 ok = e.variable(v) 143 case *ast.WhenClause: 144 ok = true 145 } 146 out = in 147 return 148 } 149 150 func (e *Evaluator) between(v *ast.BetweenExpr) bool { 151 var l, r ast.ExprNode 152 op := opcode.AndAnd 153 154 if v.Not { 155 // v < lv || v > rv 156 op = opcode.OrOr 157 l = &ast.BinaryOperationExpr{Op: opcode.LT, L: v.Expr, R: v.Left} 158 r = &ast.BinaryOperationExpr{Op: opcode.GT, L: v.Expr, R: v.Right} 159 } else { 160 // v >= lv && v <= rv 161 l = &ast.BinaryOperationExpr{Op: opcode.GE, L: v.Expr, R: v.Left} 162 r = &ast.BinaryOperationExpr{Op: opcode.LE, L: v.Expr, R: v.Right} 163 } 164 ast.MergeChildrenFlags(l, v.Expr, v.Left) 165 ast.MergeChildrenFlags(l, v.Expr, v.Right) 166 ret := &ast.BinaryOperationExpr{Op: op, L: l, R: r} 167 ast.MergeChildrenFlags(ret, l, r) 168 ret.Accept(e) 169 if e.err != nil { 170 return false 171 } 172 v.SetDatum(*ret.GetDatum()) 173 return true 174 } 175 176 func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool { 177 tmp := types.NewDatum(boolToInt64(true)) 178 target := &tmp 179 if v.Value != nil { 180 target = v.Value.GetDatum() 181 } 182 if target.Kind() != types.KindNull { 183 for _, val := range v.WhenClauses { 184 cmp, err := target.CompareDatum(*val.Expr.GetDatum()) 185 if err != nil { 186 e.err = errors.Trace(err) 187 return false 188 } 189 if cmp == 0 { 190 v.SetDatum(*val.Result.GetDatum()) 191 return true 192 } 193 } 194 } 195 if v.ElseClause != nil { 196 v.SetDatum(*v.ElseClause.GetDatum()) 197 } else { 198 v.SetNull() 199 } 200 return true 201 } 202 203 func (e *Evaluator) columnName(v *ast.ColumnNameExpr) bool { 204 v.SetDatum(*v.Refer.Expr.GetDatum()) 205 return true 206 } 207 208 func (e *Evaluator) defaultExpr(v *ast.DefaultExpr) bool { 209 return true 210 } 211 212 func (e *Evaluator) compareSubquery(cs *ast.CompareSubqueryExpr) bool { 213 lv := *cs.L.GetDatum() 214 if lv.Kind() == types.KindNull { 215 cs.SetNull() 216 return true 217 } 218 x, err := e.checkResult(cs, lv, cs.R.GetDatum().GetRow()) 219 if err != nil { 220 e.err = errors.Trace(err) 221 return false 222 } 223 cs.SetDatum(x) 224 return true 225 } 226 227 func (e *Evaluator) checkResult(cs *ast.CompareSubqueryExpr, lv types.Datum, result []types.Datum) (types.Datum, error) { 228 if cs.All { 229 return e.checkAllResult(cs, lv, result) 230 } 231 return e.checkAnyResult(cs, lv, result) 232 } 233 234 func (e *Evaluator) checkAllResult(cs *ast.CompareSubqueryExpr, lv types.Datum, result []types.Datum) (d types.Datum, err error) { 235 hasNull := false 236 for _, v := range result { 237 if v.Kind() == types.KindNull { 238 hasNull = true 239 continue 240 } 241 242 comRes, err1 := lv.CompareDatum(v) 243 if err1 != nil { 244 return d, errors.Trace(err1) 245 } 246 247 res, err1 := getCompResult(cs.Op, comRes) 248 if err1 != nil { 249 return d, errors.Trace(err1) 250 } 251 if !res { 252 d.SetInt64(boolToInt64(false)) 253 return d, nil 254 } 255 } 256 if hasNull { 257 // If no matched but we get null, return null. 258 // Like `insert t (c) values (1),(2),(null)`, then 259 // `select 3 > all (select c from t)`, returns null. 260 return d, nil 261 } 262 d.SetInt64(boolToInt64(true)) 263 return d, nil 264 } 265 266 func (e *Evaluator) checkAnyResult(cs *ast.CompareSubqueryExpr, lv types.Datum, result []types.Datum) (d types.Datum, err error) { 267 hasNull := false 268 for _, v := range result { 269 if v.Kind() == types.KindNull { 270 hasNull = true 271 continue 272 } 273 274 comRes, err1 := lv.CompareDatum(v) 275 if err1 != nil { 276 return d, errors.Trace(err1) 277 } 278 279 res, err1 := getCompResult(cs.Op, comRes) 280 if err1 != nil { 281 return d, errors.Trace(err1) 282 } 283 if res { 284 d.SetInt64(boolToInt64(true)) 285 return d, nil 286 } 287 } 288 289 if hasNull { 290 // If no matched but we get null, return null. 291 // Like `insert t (c) values (1),(2),(null)`, then 292 // `select 0 > any (select c from t)`, returns null. 293 return d, nil 294 } 295 296 d.SetInt64(boolToInt64(false)) 297 return d, nil 298 } 299 300 func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool { 301 d := v.Sel.GetDatum() 302 if d.Kind() == types.KindNull { 303 v.SetInt64(0) 304 return true 305 } 306 rows := d.GetRow() 307 if len(rows) > 0 { 308 v.SetInt64(1) 309 } else { 310 v.SetInt64(0) 311 } 312 return true 313 } 314 315 // Evaluate SubqueryExpr. 316 // Get the value from v.SubQuery and set it to v. 317 func (e *Evaluator) subqueryExpr(v *ast.SubqueryExpr) bool { 318 if v.Evaluated && !v.Correlated { 319 // Subquery do not use outer context should only evaluate once. 320 return true 321 } 322 err := EvalSubquery(e.ctx, v) 323 if err != nil { 324 e.err = errors.Trace(err) 325 return false 326 } 327 return true 328 } 329 330 // EvalSubquery evaluates a subquery. 331 func EvalSubquery(ctx context.Context, v *ast.SubqueryExpr) error { 332 if v.SubqueryExec != nil { 333 rowCount := 2 334 if v.MultiRows { 335 rowCount = -1 336 } else if v.Exists { 337 rowCount = 1 338 } 339 rows, err := v.SubqueryExec.EvalRows(ctx, rowCount) 340 if err != nil { 341 return errors.Trace(err) 342 } 343 if v.MultiRows || v.Exists { 344 v.GetDatum().SetRow(rows) 345 v.Evaluated = true 346 return nil 347 } 348 switch len(rows) { 349 case 0: 350 v.SetNull() 351 case 1: 352 v.SetDatum(rows[0]) 353 default: 354 return errors.New("Subquery returns more than 1 row") 355 } 356 } 357 v.Evaluated = true 358 return nil 359 } 360 361 func (e *Evaluator) checkInList(not bool, in types.Datum, list []types.Datum) (d types.Datum) { 362 hasNull := false 363 for _, v := range list { 364 if v.Kind() == types.KindNull { 365 hasNull = true 366 continue 367 } 368 369 a, b := types.CoerceDatum(in, v) 370 r, err := a.CompareDatum(b) 371 if err != nil { 372 e.err = errors.Trace(err) 373 return d 374 } 375 if r == 0 { 376 if !not { 377 d.SetInt64(1) 378 return d 379 } 380 d.SetInt64(0) 381 return d 382 } 383 } 384 385 if hasNull { 386 // if no matched but we got null in In, return null 387 // e.g 1 in (null, 2, 3) returns null 388 return d 389 } 390 if not { 391 d.SetInt64(1) 392 return d 393 } 394 d.SetInt64(0) 395 return d 396 } 397 398 func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool { 399 lhs := *n.Expr.GetDatum() 400 if lhs.Kind() == types.KindNull { 401 n.SetNull() 402 return true 403 } 404 if n.Sel == nil { 405 ds := make([]types.Datum, 0, len(n.List)) 406 for _, ei := range n.List { 407 ds = append(ds, *ei.GetDatum()) 408 } 409 x := e.checkInList(n.Not, lhs, ds) 410 if e.err != nil { 411 return false 412 } 413 n.SetDatum(x) 414 return true 415 } 416 res := n.Sel.GetDatum().GetRow() 417 x := e.checkInList(n.Not, lhs, res) 418 if e.err != nil { 419 return false 420 } 421 n.SetDatum(x) 422 return true 423 } 424 425 func (e *Evaluator) isNull(v *ast.IsNullExpr) bool { 426 var boolVal bool 427 if v.Expr.GetDatum().Kind() == types.KindNull { 428 boolVal = true 429 } 430 if v.Not { 431 boolVal = !boolVal 432 } 433 v.SetInt64(boolToInt64(boolVal)) 434 return true 435 } 436 437 func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool { 438 var boolVal bool 439 datum := v.Expr.GetDatum() 440 if datum.Kind() != types.KindNull { 441 ival, err := datum.ToBool() 442 if err != nil { 443 e.err = errors.Trace(err) 444 return false 445 } 446 if ival == v.True { 447 boolVal = true 448 } 449 } 450 if v.Not { 451 boolVal = !boolVal 452 } 453 v.GetDatum().SetInt64(boolToInt64(boolVal)) 454 return true 455 } 456 457 func (e *Evaluator) paramMarker(v *ast.ParamMarkerExpr) bool { 458 return true 459 } 460 461 func (e *Evaluator) parentheses(v *ast.ParenthesesExpr) bool { 462 v.SetDatum(*v.Expr.GetDatum()) 463 return true 464 } 465 466 func (e *Evaluator) position(v *ast.PositionExpr) bool { 467 v.SetDatum(*v.Refer.Expr.GetDatum()) 468 return true 469 } 470 471 func (e *Evaluator) row(v *ast.RowExpr) bool { 472 row := make([]types.Datum, 0, len(v.Values)) 473 for _, val := range v.Values { 474 row = append(row, *val.GetDatum()) 475 } 476 v.GetDatum().SetRow(row) 477 return true 478 } 479 480 func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool { 481 defer func() { 482 if er := recover(); er != nil { 483 e.err = errors.Errorf("%v", er) 484 } 485 }() 486 aDatum := u.V.GetDatum() 487 if aDatum.Kind() == types.KindNull { 488 u.SetNull() 489 return true 490 } 491 switch op := u.Op; op { 492 case opcode.Not: 493 n, err := aDatum.ToBool() 494 if err != nil { 495 e.err = errors.Trace(err) 496 } else if n == 0 { 497 u.SetInt64(1) 498 } else { 499 u.SetInt64(0) 500 } 501 case opcode.BitNeg: 502 // for bit operation, we will use int64 first, then return uint64 503 n, err := aDatum.ToInt64() 504 if err != nil { 505 e.err = errors.Trace(err) 506 return false 507 } 508 u.SetUint64(uint64(^n)) 509 case opcode.Plus: 510 switch aDatum.Kind() { 511 case types.KindInt64, 512 types.KindUint64, 513 types.KindFloat64, 514 types.KindFloat32, 515 types.KindMysqlDuration, 516 types.KindMysqlTime, 517 types.KindString, 518 types.KindMysqlDecimal, 519 types.KindBytes, 520 types.KindMysqlHex, 521 types.KindMysqlBit, 522 types.KindMysqlEnum, 523 types.KindMysqlSet: 524 u.SetDatum(*aDatum) 525 default: 526 e.err = ErrInvalidOperation 527 return false 528 } 529 case opcode.Minus: 530 switch aDatum.Kind() { 531 case types.KindInt64: 532 u.SetInt64(-aDatum.GetInt64()) 533 case types.KindUint64: 534 u.SetInt64(-int64(aDatum.GetUint64())) 535 case types.KindFloat64: 536 u.SetFloat64(-aDatum.GetFloat64()) 537 case types.KindFloat32: 538 u.SetFloat32(-aDatum.GetFloat32()) 539 case types.KindMysqlDuration: 540 u.SetMysqlDecimal(mysql.ZeroDecimal.Sub(aDatum.GetMysqlDuration().ToNumber())) 541 case types.KindMysqlTime: 542 u.SetMysqlDecimal(mysql.ZeroDecimal.Sub(aDatum.GetMysqlTime().ToNumber())) 543 case types.KindString, types.KindBytes: 544 f, err := types.StrToFloat(aDatum.GetString()) 545 e.err = errors.Trace(err) 546 u.SetFloat64(-f) 547 case types.KindMysqlDecimal: 548 f, _ := aDatum.GetMysqlDecimal().Float64() 549 u.SetMysqlDecimal(mysql.NewDecimalFromFloat(-f)) 550 case types.KindMysqlHex: 551 u.SetFloat64(-aDatum.GetMysqlHex().ToNumber()) 552 case types.KindMysqlBit: 553 u.SetFloat64(-aDatum.GetMysqlBit().ToNumber()) 554 case types.KindMysqlEnum: 555 u.SetFloat64(-aDatum.GetMysqlEnum().ToNumber()) 556 case types.KindMysqlSet: 557 u.SetFloat64(-aDatum.GetMysqlSet().ToNumber()) 558 default: 559 e.err = ErrInvalidOperation 560 return false 561 } 562 default: 563 e.err = ErrInvalidOperation 564 return false 565 } 566 567 return true 568 } 569 570 func (e *Evaluator) values(v *ast.ValuesExpr) bool { 571 v.SetDatum(*v.Column.GetDatum()) 572 return true 573 } 574 575 func (e *Evaluator) variable(v *ast.VariableExpr) bool { 576 name := strings.ToLower(v.Name) 577 sessionVars := variable.GetSessionVars(e.ctx) 578 globalVars := variable.GetGlobalVarAccessor(e.ctx) 579 if !v.IsSystem { 580 if v.Value != nil && v.Value.GetDatum().Kind() != types.KindNull { 581 strVal, err := v.Value.GetDatum().ToString() 582 if err != nil { 583 e.err = errors.Trace(err) 584 return false 585 } 586 sessionVars.Users[name] = strings.ToLower(strVal) 587 v.SetString(strVal) 588 return true 589 } 590 // user vars 591 if value, ok := sessionVars.Users[name]; ok { 592 v.SetString(value) 593 return true 594 } 595 // select null user vars is permitted. 596 v.SetNull() 597 return true 598 } 599 600 _, ok := variable.SysVars[name] 601 if !ok { 602 // select null sys vars is not permitted 603 e.err = variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name) 604 return false 605 } 606 607 if !v.IsGlobal { 608 if value, ok := sessionVars.Systems[name]; ok { 609 v.SetString(value) 610 return true 611 } 612 } 613 614 value, err := globalVars.GetGlobalSysVar(e.ctx, name) 615 if err != nil { 616 e.err = errors.Trace(err) 617 return false 618 } 619 620 v.SetString(value) 621 return true 622 } 623 624 func (e *Evaluator) funcCall(v *ast.FuncCallExpr) bool { 625 f, ok := Funcs[v.FnName.L] 626 if !ok { 627 e.err = ErrInvalidOperation.Gen("unknown function %s", v.FnName.O) 628 return false 629 } 630 if len(v.Args) < f.MinArgs || (f.MaxArgs != -1 && len(v.Args) > f.MaxArgs) { 631 e.err = ErrInvalidOperation.Gen("number of function arguments must in [%d, %d].", f.MinArgs, f.MaxArgs) 632 return false 633 } 634 a := make([]types.Datum, len(v.Args)) 635 for i, arg := range v.Args { 636 a[i] = *arg.GetDatum() 637 } 638 val, err := f.F(a, e.ctx) 639 if err != nil { 640 e.err = errors.Trace(err) 641 return false 642 } 643 v.SetDatum(val) 644 return true 645 } 646 647 func (e *Evaluator) funcCast(v *ast.FuncCastExpr) bool { 648 d := *v.Expr.GetDatum() 649 // Casting nil to any type returns null 650 if d.Kind() == types.KindNull { 651 v.SetNull() 652 return true 653 } 654 var err error 655 d, err = d.Cast(v.Tp) 656 if err != nil { 657 e.err = errors.Trace(err) 658 return false 659 } 660 v.SetDatum(d) 661 return true 662 } 663 664 func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool { 665 name := strings.ToLower(v.F) 666 switch name { 667 case ast.AggFuncAvg: 668 e.evalAggAvg(v) 669 case ast.AggFuncCount: 670 e.evalAggCount(v) 671 case ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncSum: 672 e.evalAggSetValue(v) 673 case ast.AggFuncGroupConcat: 674 e.evalAggGroupConcat(v) 675 } 676 return e.err == nil 677 } 678 679 func (e *Evaluator) evalAggCount(v *ast.AggregateFuncExpr) { 680 ctx := v.GetContext() 681 v.SetInt64(ctx.Count) 682 } 683 684 func (e *Evaluator) evalAggSetValue(v *ast.AggregateFuncExpr) { 685 ctx := v.GetContext() 686 v.SetValue(ctx.Value) 687 } 688 689 func (e *Evaluator) evalAggAvg(v *ast.AggregateFuncExpr) { 690 ctx := v.GetContext() 691 switch x := ctx.Value.(type) { 692 case float64: 693 t := x / float64(ctx.Count) 694 ctx.Value = t 695 v.SetFloat64(t) 696 case mysql.Decimal: 697 t := x.Div(mysql.NewDecimalFromUint(uint64(ctx.Count), 0)) 698 ctx.Value = t 699 v.SetMysqlDecimal(t) 700 } 701 } 702 703 func (e *Evaluator) evalAggGroupConcat(v *ast.AggregateFuncExpr) { 704 ctx := v.GetContext() 705 if ctx.Buffer != nil { 706 v.SetString(ctx.Buffer.String()) 707 } else { 708 v.SetNull() 709 } 710 }