vitess.io/vitess@v0.16.2/go/vt/vtgate/evalengine/func.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package evalengine 18 19 import ( 20 "bytes" 21 "fmt" 22 "math" 23 "math/bits" 24 25 "vitess.io/vitess/go/mysql/collations" 26 "vitess.io/vitess/go/sqltypes" 27 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 28 "vitess.io/vitess/go/vt/sqlparser" 29 "vitess.io/vitess/go/vt/vterrors" 30 ) 31 32 var builtinFunctions = map[string]builtin{ 33 "coalesce": builtinCoalesce{}, 34 "greatest": &builtinMultiComparison{name: "GREATEST", cmp: 1}, 35 "least": &builtinMultiComparison{name: "LEAST", cmp: -1}, 36 "collation": builtinCollation{}, 37 "bit_count": builtinBitCount{}, 38 "hex": builtinHex{}, 39 "ceil": builtinCeil{}, 40 "ceiling": builtinCeiling{}, 41 "lower": builtinLower{}, 42 "lcase": builtinLcase{}, 43 "upper": builtinUpper{}, 44 "ucase": builtinUcase{}, 45 "char_length": builtinCharLength{}, 46 "character_length": builtinCharacterLength{}, 47 "length": builtinLength{}, 48 "octet_length": builtinOctetLength{}, 49 "bit_length": builtinBitLength{}, 50 "ascii": builtinASCII{}, 51 "repeat": builtinRepeat{}, 52 } 53 54 var builtinFunctionsRewrite = map[string]builtinRewrite{ 55 "isnull": builtinIsNullRewrite, 56 "ifnull": builtinIfNullRewrite, 57 "nullif": builtinNullIfRewrite, 58 } 59 60 type builtin interface { 61 call(*ExpressionEnv, []EvalResult, *EvalResult) 62 typeof(*ExpressionEnv, []Expr) (sqltypes.Type, flag) 63 } 64 65 type builtinRewrite func([]Expr, TranslationLookup) (Expr, error) 66 67 type CallExpr struct { 68 Arguments TupleExpr 69 Aliases []sqlparser.IdentifierCI 70 Method string 71 F builtin 72 } 73 74 func (c *CallExpr) typeof(env *ExpressionEnv) (sqltypes.Type, flag) { 75 return c.F.typeof(env, c.Arguments) 76 } 77 78 func (c *CallExpr) eval(env *ExpressionEnv, result *EvalResult) { 79 var args = make([]EvalResult, len(c.Arguments)) 80 for i, arg := range c.Arguments { 81 args[i].init(env, arg) 82 } 83 c.F.call(env, args, result) 84 } 85 86 type builtinCoalesce struct{} 87 88 func (builtinCoalesce) call(_ *ExpressionEnv, args []EvalResult, result *EvalResult) { 89 for _, arg := range args { 90 if !arg.isNull() { 91 *result = arg 92 result.resolve() 93 return 94 } 95 } 96 result.setNull() 97 } 98 99 func (builtinCoalesce) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) { 100 var ta typeAggregation 101 for _, arg := range args { 102 tt, f := arg.typeof(env) 103 ta.add(tt, f) 104 } 105 return ta.result(), flagNullable 106 } 107 108 type multiComparisonFunc func(args []EvalResult, result *EvalResult, cmp int) 109 110 func getMultiComparisonFunc(args []EvalResult) multiComparisonFunc { 111 var ( 112 integers int 113 floats int 114 decimals int 115 text int 116 binary int 117 ) 118 119 /* 120 If any argument is NULL, the result is NULL. No comparison is needed. 121 If all arguments are integer-valued, they are compared as integers. 122 If at least one argument is double precision, they are compared as double-precision values. Otherwise, if at least one argument is a DECIMAL value, they are compared as DECIMAL values. 123 If the arguments comprise a mix of numbers and strings, they are compared as strings. 124 If any argument is a nonbinary (character) string, the arguments are compared as nonbinary strings. 125 In all other cases, the arguments are compared as binary strings. 126 */ 127 128 for i := range args { 129 arg := &args[i] 130 if arg.isNull() { 131 return func(args []EvalResult, result *EvalResult, cmp int) { 132 result.setNull() 133 } 134 } 135 136 switch arg.typeof() { 137 case sqltypes.Int8, sqltypes.Int16, sqltypes.Int32, sqltypes.Int64: 138 integers++ 139 case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint32, sqltypes.Uint64: 140 if arg.uint64() > math.MaxInt64 { 141 decimals++ 142 } else { 143 integers++ 144 } 145 case sqltypes.Float32, sqltypes.Float64: 146 floats++ 147 case sqltypes.Decimal: 148 decimals++ 149 case sqltypes.Text, sqltypes.VarChar: 150 text++ 151 case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: 152 binary++ 153 } 154 } 155 156 if integers == len(args) { 157 return compareAllInteger 158 } 159 if binary > 0 || text > 0 { 160 if binary > 0 { 161 return compareAllBinary 162 } 163 if text > 0 { 164 return compareAllText 165 } 166 } else { 167 if floats > 0 { 168 return compareAllFloat 169 } 170 if decimals > 0 { 171 return compareAllDecimal 172 } 173 } 174 panic("unexpected argument type") 175 } 176 177 func compareAllInteger(args []EvalResult, result *EvalResult, cmp int) { 178 var candidateI = args[0].int64() 179 for _, arg := range args[1:] { 180 thisI := arg.int64() 181 if (cmp < 0) == (thisI < candidateI) { 182 candidateI = thisI 183 } 184 } 185 result.setInt64(candidateI) 186 } 187 188 func compareAllFloat(args []EvalResult, result *EvalResult, cmp int) { 189 candidateF, err := args[0].coerceToFloat() 190 if err != nil { 191 throwEvalError(err) 192 } 193 194 for _, arg := range args[1:] { 195 thisF, err := arg.coerceToFloat() 196 if err != nil { 197 throwEvalError(err) 198 } 199 if (cmp < 0) == (thisF < candidateF) { 200 candidateF = thisF 201 } 202 } 203 result.setFloat(candidateF) 204 } 205 206 func compareAllDecimal(args []EvalResult, result *EvalResult, cmp int) { 207 candidateD := args[0].coerceToDecimal() 208 maxFrac := args[0].length_ 209 210 for _, arg := range args[1:] { 211 thisD := arg.coerceToDecimal() 212 if (cmp < 0) == (thisD.Cmp(candidateD) < 0) { 213 candidateD = thisD 214 } 215 if arg.length_ > maxFrac { 216 maxFrac = arg.length_ 217 } 218 } 219 220 result.setDecimal(candidateD, maxFrac) 221 } 222 223 func compareAllText(args []EvalResult, result *EvalResult, cmp int) { 224 env := collations.Local() 225 candidateB := args[0].toRawBytes() 226 collationB := args[0].collation() 227 228 var ca collationAggregation 229 ca.add(env, collationB) 230 231 for _, arg := range args[1:] { 232 thisB := arg.toRawBytes() 233 thisColl := arg.collation() 234 ca.add(env, thisColl) 235 236 thisTC, coerceLeft, coerceRight, err := env.MergeCollations(thisColl, collationB, collations.CoercionOptions{ConvertToSuperset: true, ConvertWithCoercion: true}) 237 if err != nil { 238 throwEvalError(err) 239 } 240 241 collation := env.LookupByID(thisTC.Collation) 242 243 var leftB = thisB 244 var rightB = candidateB 245 if coerceLeft != nil { 246 leftB, _ = coerceLeft(nil, leftB) 247 } 248 if coerceRight != nil { 249 rightB, _ = coerceRight(nil, rightB) 250 } 251 if (cmp < 0) == (collation.Collate(leftB, rightB, false) < 0) { 252 candidateB = thisB 253 } 254 } 255 256 result.setRaw(sqltypes.VarChar, candidateB, ca.result()) 257 } 258 259 type collationAggregation struct { 260 cur collations.TypedCollation 261 init bool 262 } 263 264 func (ca *collationAggregation) add(env *collations.Environment, tc collations.TypedCollation) { 265 if !ca.init { 266 ca.cur = tc 267 ca.init = true 268 } else { 269 var err error 270 ca.cur, _, _, err = env.MergeCollations(ca.cur, tc, collations.CoercionOptions{ConvertToSuperset: true, ConvertWithCoercion: true}) 271 if err != nil { 272 throwEvalError(err) 273 } 274 } 275 } 276 277 func (ca *collationAggregation) result() collations.TypedCollation { 278 return ca.cur 279 } 280 281 func compareAllBinary(args []EvalResult, result *EvalResult, cmp int) { 282 candidateB := args[0].toRawBytes() 283 284 for _, arg := range args[1:] { 285 thisB := arg.toRawBytes() 286 if (cmp < 0) == (bytes.Compare(thisB, candidateB) < 0) { 287 candidateB = thisB 288 } 289 } 290 291 result.setRaw(sqltypes.VarBinary, candidateB, collationBinary) 292 } 293 294 type argError string 295 296 func (err argError) Error() string { 297 return fmt.Sprintf("Incorrect parameter count in the call to native function '%s'", string(err)) 298 } 299 300 func throwArgError(fname string) { 301 panic(evalError{argError(fname)}) 302 } 303 304 type builtinMultiComparison struct { 305 name string 306 cmp int 307 } 308 309 func (cmp *builtinMultiComparison) call(_ *ExpressionEnv, args []EvalResult, result *EvalResult) { 310 getMultiComparisonFunc(args)(args, result, cmp.cmp) 311 } 312 313 func (cmp *builtinMultiComparison) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) { 314 if len(args) < 2 { 315 throwArgError(cmp.name) 316 } 317 318 var ( 319 integers int 320 floats int 321 decimals int 322 text int 323 binary int 324 flags flag 325 ) 326 327 for _, expr := range args { 328 tt, f := expr.typeof(env) 329 flags |= f 330 331 switch tt { 332 case sqltypes.Int8, sqltypes.Int16, sqltypes.Int32, sqltypes.Int64: 333 integers++ 334 case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint32, sqltypes.Uint64: 335 if f&flagIntegerOvf != 0 { 336 decimals++ 337 } else { 338 integers++ 339 } 340 case sqltypes.Float32, sqltypes.Float64: 341 floats++ 342 case sqltypes.Decimal: 343 decimals++ 344 case sqltypes.Text, sqltypes.VarChar: 345 text++ 346 case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: 347 binary++ 348 } 349 } 350 351 if flags&flagNull != 0 { 352 return sqltypes.Null, flags 353 } 354 if integers == len(args) { 355 return sqltypes.Int64, flags 356 } 357 if binary > 0 || text > 0 { 358 if binary > 0 { 359 return sqltypes.VarBinary, flags 360 } 361 if text > 0 { 362 return sqltypes.VarChar, flags 363 } 364 } else { 365 if floats > 0 { 366 return sqltypes.Float64, flags 367 } 368 if decimals > 0 { 369 return sqltypes.Decimal, flags 370 } 371 } 372 panic("unexpected argument type") 373 } 374 375 type builtinCollation struct{} 376 377 func (builtinCollation) call(env *ExpressionEnv, args []EvalResult, result *EvalResult) { 378 coll := collations.Local().LookupByID(args[0].collation().Collation) 379 380 // the collation of a `COLLATION` expr is hardcoded to `utf8_general_ci`, 381 // not to the default collation of our connection. this is probably a bug in MySQL, but we match it 382 result.setString(coll.Name(), collations.TypedCollation{ 383 Collation: collations.CollationUtf8ID, 384 Coercibility: collations.CoerceImplicit, 385 Repertoire: collations.RepertoireASCII, 386 }) 387 } 388 389 func (builtinCollation) typeof(_ *ExpressionEnv, args []Expr) (sqltypes.Type, flag) { 390 if len(args) != 1 { 391 throwArgError("COLLATION") 392 } 393 return sqltypes.VarChar, 0 394 } 395 396 func builtinIsNullRewrite(args []Expr, _ TranslationLookup) (Expr, error) { 397 if len(args) != 1 { 398 return nil, argError("ISNULL") 399 } 400 return &IsExpr{ 401 UnaryExpr: UnaryExpr{args[0]}, 402 Op: sqlparser.IsNullOp, 403 Check: func(er *EvalResult) bool { return er.isNull() }, 404 }, nil 405 } 406 407 func builtinIfNullRewrite(args []Expr, _ TranslationLookup) (Expr, error) { 408 if len(args) != 2 { 409 return nil, argError("IFNULL") 410 } 411 var result CaseExpr 412 result.cases = append(result.cases, WhenThen{ 413 when: &IsExpr{ 414 UnaryExpr: UnaryExpr{args[0]}, 415 Op: sqlparser.IsNullOp, 416 Check: func(er *EvalResult) bool { return er.isNull() }, 417 }, 418 then: args[1], 419 }) 420 result.Else = args[0] 421 return &result, nil 422 } 423 424 func builtinNullIfRewrite(args []Expr, _ TranslationLookup) (Expr, error) { 425 if len(args) != 2 { 426 return nil, argError("NULLIF") 427 } 428 var result CaseExpr 429 result.cases = append(result.cases, WhenThen{ 430 when: &ComparisonExpr{ 431 BinaryExpr: BinaryExpr{ 432 Left: args[0], 433 Right: args[1], 434 }, 435 Op: compareEQ{}, 436 }, 437 then: NullExpr, 438 }) 439 result.Else = args[0] 440 return &result, nil 441 } 442 443 type builtinBitCount struct{} 444 445 func (builtinBitCount) call(_ *ExpressionEnv, args []EvalResult, result *EvalResult) { 446 var count int 447 inarg := &args[0] 448 449 if inarg.isNull() { 450 result.setNull() 451 return 452 } 453 454 if inarg.isBitwiseBinaryString() { 455 binary := inarg.bytes() 456 for _, b := range binary { 457 count += bits.OnesCount8(b) 458 } 459 } else { 460 inarg.makeUnsignedIntegral() 461 count = bits.OnesCount64(inarg.uint64()) 462 } 463 464 result.setInt64(int64(count)) 465 } 466 467 func (builtinBitCount) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) { 468 if len(args) != 1 { 469 throwArgError("BIT_COUNT") 470 } 471 472 _, f := args[0].typeof(env) 473 return sqltypes.Int64, f 474 } 475 476 type WeightStringCallExpr struct { 477 String Expr 478 Cast string 479 Len int 480 HasLen bool 481 } 482 483 func (c *WeightStringCallExpr) typeof(env *ExpressionEnv) (sqltypes.Type, flag) { 484 _, f := c.String.typeof(env) 485 return sqltypes.VarBinary, f 486 } 487 488 func (c *WeightStringCallExpr) eval(env *ExpressionEnv, result *EvalResult) { 489 var ( 490 str EvalResult 491 tc collations.TypedCollation 492 text []byte 493 weights []byte 494 length = c.Len 495 ) 496 497 str.init(env, c.String) 498 tt := str.typeof() 499 500 switch { 501 case sqltypes.IsIntegral(tt): 502 // when calling WEIGHT_STRING with an integral value, MySQL returns the 503 // internal sort key that would be used in an InnoDB table... we do not 504 // support that 505 throwEvalError(vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "%s: %s", ErrEvaluatedExprNotSupported, FormatExpr(c))) 506 case sqltypes.IsQuoted(tt): 507 text = str.bytes() 508 tc = str.collation() 509 default: 510 result.setNull() 511 return 512 } 513 514 if c.Cast == "binary" { 515 tc = collationBinary 516 weights = make([]byte, 0, c.Len) 517 length = collations.PadToMax 518 } 519 520 collation := collations.Local().LookupByID(tc.Collation) 521 weights = collation.WeightString(weights, text, length) 522 result.setRaw(sqltypes.VarBinary, weights, collationBinary) 523 } 524 525 type typeAggregation struct { 526 double uint16 527 decimal uint16 528 signed uint16 529 unsigned uint16 530 531 signedMax sqltypes.Type 532 unsignedMax sqltypes.Type 533 534 bit uint16 535 year uint16 536 char uint16 537 binary uint16 538 charother uint16 539 json uint16 540 541 date uint16 542 time uint16 543 timestamp uint16 544 datetime uint16 545 546 geometry uint16 547 blob uint16 548 total uint16 549 } 550 551 func (ta *typeAggregation) add(tt sqltypes.Type, f flag) { 552 switch tt { 553 case sqltypes.Float32, sqltypes.Float64: 554 ta.double++ 555 case sqltypes.Decimal: 556 ta.decimal++ 557 case sqltypes.Int8, sqltypes.Int16, sqltypes.Int24, sqltypes.Int32, sqltypes.Int64: 558 ta.signed++ 559 if tt > ta.signedMax { 560 ta.signedMax = tt 561 } 562 case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: 563 ta.unsigned++ 564 if tt > ta.unsignedMax { 565 ta.unsignedMax = tt 566 } 567 case sqltypes.Bit: 568 ta.bit++ 569 case sqltypes.Year: 570 ta.year++ 571 case sqltypes.Char, sqltypes.VarChar, sqltypes.Set, sqltypes.Enum: 572 if f&flagExplicitCollation != 0 { 573 ta.charother++ 574 } 575 ta.char++ 576 case sqltypes.Binary, sqltypes.VarBinary: 577 if f&flagHex != 0 { 578 ta.charother++ 579 } 580 ta.binary++ 581 case sqltypes.TypeJSON: 582 ta.json++ 583 case sqltypes.Date: 584 ta.date++ 585 case sqltypes.Datetime: 586 ta.datetime++ 587 case sqltypes.Time: 588 ta.time++ 589 case sqltypes.Timestamp: 590 ta.timestamp++ 591 case sqltypes.Geometry: 592 ta.geometry++ 593 case sqltypes.Blob: 594 ta.blob++ 595 default: 596 return 597 } 598 ta.total++ 599 } 600 601 func (ta *typeAggregation) result() sqltypes.Type { 602 /* 603 If all types are numeric, the aggregated type is also numeric: 604 If at least one argument is double precision, the result is double precision. 605 Otherwise, if at least one argument is DECIMAL, the result is DECIMAL. 606 Otherwise, the result is an integer type (with one exception): 607 If all integer types are all signed or all unsigned, the result is the same sign and the precision is the highest of all specified integer types (that is, TINYINT, SMALLINT, MEDIUMINT, INT, or BIGINT). 608 If there is a combination of signed and unsigned integer types, the result is signed and the precision may be higher. For example, if the types are signed INT and unsigned INT, the result is signed BIGINT. 609 The exception is unsigned BIGINT combined with any signed integer type. The result is DECIMAL with sufficient precision and scale 0. 610 If all types are BIT, the result is BIT. Otherwise, BIT arguments are treated similar to BIGINT. 611 If all types are YEAR, the result is YEAR. Otherwise, YEAR arguments are treated similar to INT. 612 If all types are character string (CHAR or VARCHAR), the result is VARCHAR with maximum length determined by the longest character length of the operands. 613 If all types are character or binary string, the result is VARBINARY. 614 SET and ENUM are treated similar to VARCHAR; the result is VARCHAR. 615 If all types are JSON, the result is JSON. 616 If all types are temporal, the result is temporal: 617 If all temporal types are DATE, TIME, or TIMESTAMP, the result is DATE, TIME, or TIMESTAMP, respectively. 618 Otherwise, for a mix of temporal types, the result is DATETIME. 619 If all types are GEOMETRY, the result is GEOMETRY. 620 If any type is BLOB, the result is BLOB. 621 For all other type combinations, the result is VARCHAR. 622 Literal NULL operands are ignored for type aggregation. 623 */ 624 625 if ta.bit == ta.total { 626 return sqltypes.Bit 627 } else if ta.bit > 0 { 628 ta.signed += ta.bit 629 ta.signedMax = sqltypes.Int64 630 } 631 632 if ta.year == ta.total { 633 return sqltypes.Year 634 } else if ta.year > 0 { 635 ta.signed += ta.year 636 if sqltypes.Int32 > ta.signedMax { 637 ta.signedMax = sqltypes.Int32 638 } 639 } 640 641 if ta.double+ta.decimal+ta.signed+ta.unsigned == ta.total { 642 if ta.double > 0 { 643 return sqltypes.Float64 644 } 645 if ta.decimal > 0 { 646 return sqltypes.Decimal 647 } 648 if ta.signed == ta.total { 649 return ta.signedMax 650 } 651 if ta.unsigned == ta.total { 652 return ta.unsignedMax 653 } 654 if ta.unsignedMax == sqltypes.Uint64 && ta.signed > 0 { 655 return sqltypes.Decimal 656 } 657 // TODO 658 return sqltypes.Uint64 659 } 660 661 if ta.char == ta.total { 662 return sqltypes.VarChar 663 } 664 if ta.char+ta.binary == ta.total { 665 // HACK: this is not in the official documentation, but groups of strings where 666 // one of the strings is not directly a VARCHAR or VARBINARY (e.g. a hex literal, 667 // or a VARCHAR that has been explicitly collated) will result in VARCHAR when 668 // aggregated 669 if ta.charother > 0 { 670 return sqltypes.VarChar 671 } 672 return sqltypes.VarBinary 673 } 674 if ta.json == ta.total { 675 return sqltypes.TypeJSON 676 } 677 if ta.date+ta.time+ta.timestamp+ta.datetime == ta.total { 678 if ta.date == ta.total { 679 return sqltypes.Date 680 } 681 if ta.time == ta.total { 682 return sqltypes.Time 683 } 684 if ta.timestamp == ta.total { 685 return sqltypes.Timestamp 686 } 687 return sqltypes.Datetime 688 } 689 if ta.geometry == ta.total { 690 return sqltypes.Geometry 691 } 692 if ta.blob > 0 { 693 return sqltypes.Blob 694 } 695 return sqltypes.VarChar 696 } 697 698 type builtinCeil struct { 699 } 700 701 func (builtinCeil) call(env *ExpressionEnv, args []EvalResult, result *EvalResult) { 702 inarg := &args[0] 703 argtype := inarg.typeof() 704 if inarg.isNull() { 705 result.setNull() 706 return 707 } 708 709 if sqltypes.IsIntegral(argtype) { 710 result.setInt64(inarg.int64()) 711 } else if sqltypes.Decimal == argtype { 712 num := inarg.decimal() 713 num = num.Ceil() 714 intnum, isfit := num.Int64() 715 if isfit { 716 result.setInt64(intnum) 717 } else { 718 result.setDecimal(num, 0) 719 } 720 } else { 721 inarg.makeFloat() 722 result.setFloat(math.Ceil(inarg.float64())) 723 } 724 } 725 726 func (builtinCeil) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) { 727 if len(args) != 1 { 728 throwArgError("CEIL") 729 } 730 t, f := args[0].typeof(env) 731 if sqltypes.IsIntegral(t) { 732 return sqltypes.Int64, f 733 } else if sqltypes.Decimal == t { 734 return sqltypes.Decimal, f 735 } else { 736 return sqltypes.Float64, f 737 } 738 } 739 740 type builtinCeiling struct { 741 builtinCeil 742 } 743 744 func (builtinCeiling) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) { 745 if len(args) != 1 { 746 throwArgError("CEILING") 747 } 748 t, f := args[0].typeof(env) 749 if sqltypes.IsIntegral(t) { 750 return sqltypes.Int64, f 751 } else if sqltypes.Decimal == t { 752 return sqltypes.Decimal, f 753 } else { 754 return sqltypes.Float64, f 755 } 756 }