github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/string.go (about) 1 // Copyright 2020-2024 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package function 16 17 import ( 18 "encoding/hex" 19 "fmt" 20 "math" 21 "strconv" 22 "strings" 23 "time" 24 "unsafe" 25 26 "github.com/shopspring/decimal" 27 28 "github.com/dolthub/go-mysql-server/sql" 29 "github.com/dolthub/go-mysql-server/sql/encodings" 30 "github.com/dolthub/go-mysql-server/sql/types" 31 ) 32 33 // Ascii implements the sql function "ascii" which returns the numeric value of the leftmost character 34 type Ascii struct { 35 *UnaryFunc 36 } 37 38 var _ sql.FunctionExpression = (*Ascii)(nil) 39 var _ sql.CollationCoercible = (*Ascii)(nil) 40 41 func NewAscii(arg sql.Expression) sql.Expression { 42 return &Ascii{NewUnaryFunc(arg, "ASCII", types.Uint8)} 43 } 44 45 // Description implements sql.FunctionExpression 46 func (a *Ascii) Description() string { 47 return "returns the numeric value of the leftmost character." 48 } 49 50 // CollationCoercibility implements the interface sql.CollationCoercible. 51 func (*Ascii) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 52 return sql.Collation_binary, 5 53 } 54 55 // Eval implements the sql.Expression interface 56 func (a *Ascii) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 57 val, err := a.EvalChild(ctx, row) 58 if err != nil { 59 return nil, err 60 } 61 62 if val == nil { 63 return nil, nil 64 } 65 66 str, _, err := types.Text.Convert(val) 67 68 if err != nil { 69 return nil, err 70 } 71 72 s := str.(string) 73 if len(s) == 0 { 74 return uint8(0), nil 75 } 76 77 return s[0], nil 78 } 79 80 // WithChildren implements the sql.Expression interface 81 func (a *Ascii) WithChildren(children ...sql.Expression) (sql.Expression, error) { 82 if len(children) != 1 { 83 return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 1) 84 } 85 return NewAscii(children[0]), nil 86 } 87 88 // Ord implements the sql function "ord" which returns the numeric value of the leftmost character 89 type Ord struct { 90 *UnaryFunc 91 } 92 93 var _ sql.FunctionExpression = (*Ord)(nil) 94 var _ sql.CollationCoercible = (*Ord)(nil) 95 96 func NewOrd(arg sql.Expression) sql.Expression { 97 return &Ord{NewUnaryFunc(arg, "ORD", types.Int64)} 98 } 99 100 // Description implements sql.FunctionExpression 101 func (o *Ord) Description() string { 102 return "return character code for leftmost character of the argument." 103 } 104 105 // CollationCoercibility implements the interface sql.CollationCoercible. 106 func (o *Ord) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 107 return sql.Collation_binary, 5 108 } 109 110 // Eval implements the sql.Expression interface 111 func (o *Ord) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 112 val, err := o.EvalChild(ctx, row) 113 if err != nil { 114 return nil, err 115 } 116 117 if val == nil { 118 return nil, nil 119 } 120 121 str, _, err := types.Text.Convert(val) 122 if err != nil { 123 return nil, err 124 } 125 s := str.(string) 126 if len(s) == 0 { 127 return int64(0), nil 128 } 129 130 // get the leftmost unicode code point as bytes 131 b := []byte(string([]rune(s)[0])) 132 133 // convert into ord 134 var res int64 135 for i, c := range b { 136 res += int64(c) << (8 * (len(b) - 1 - i)) 137 } 138 139 return res, nil 140 } 141 142 // WithChildren implements the sql.Expression interface 143 func (o *Ord) WithChildren(children ...sql.Expression) (sql.Expression, error) { 144 if len(children) != 1 { 145 return nil, sql.ErrInvalidChildrenNumber.New(o, len(children), 1) 146 } 147 return NewOrd(children[0]), nil 148 } 149 150 // Hex implements the sql function "hex" which returns the hexadecimal representation of the string or numeric value 151 type Hex struct { 152 *UnaryFunc 153 } 154 155 var _ sql.FunctionExpression = (*Hex)(nil) 156 var _ sql.CollationCoercible = (*Hex)(nil) 157 158 func NewHex(arg sql.Expression) sql.Expression { 159 // Although this may seem convoluted, the Collation_Default is NOT guaranteed to be the character set's default 160 // collation. This ensures that you're getting the character set's default collation, and also works in the event 161 // that the Collation_Default is ever changed. 162 retType := types.CreateLongText(sql.Collation_Default.CharacterSet().DefaultCollation()) 163 return &Hex{NewUnaryFunc(arg, "HEX", retType)} 164 } 165 166 // Description implements sql.FunctionExpression 167 func (h *Hex) Description() string { 168 return "returns the hexadecimal representation of the string or numeric value." 169 } 170 171 // CollationCoercibility implements the interface sql.CollationCoercible. 172 func (*Hex) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 173 return ctx.GetCollation(), 4 174 } 175 176 // Eval implements the sql.Expression interface 177 func (h *Hex) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 178 arg, err := h.EvalChild(ctx, row) 179 if err != nil { 180 return nil, err 181 } 182 183 if arg == nil { 184 return nil, nil 185 } 186 187 switch val := arg.(type) { 188 case string: 189 childType := h.Child.Type() 190 if types.IsTextOnly(childType) { 191 // For string types we need to re-encode the internal string so that we get the correct hex output 192 encoder := childType.(sql.StringType).Collation().CharacterSet().Encoder() 193 encodedBytes, ok := encoder.Encode(encodings.StringToBytes(val)) 194 if !ok { 195 return nil, fmt.Errorf("unable to re-encode string for HEX function") 196 } 197 return hexForString(encodings.BytesToString(encodedBytes)), nil 198 } else { 199 return hexForString(val), nil 200 } 201 202 case uint8, uint16, uint32, uint, int, int8, int16, int32, int64: 203 n, _, err := types.Int64.Convert(arg) 204 205 if err != nil { 206 return nil, err 207 } 208 209 a := n.(int64) 210 if a < 0 { 211 return hexForNegativeInt64(a), nil 212 } else { 213 return fmt.Sprintf("%X", a), nil 214 } 215 216 case uint64: 217 return fmt.Sprintf("%X", val), nil 218 219 case float32: 220 return hexForFloat(float64(val)) 221 222 case float64: 223 return hexForFloat(val) 224 225 case decimal.Decimal: 226 f, _ := val.Float64() 227 return hexForFloat(f) 228 229 case bool: 230 if val { 231 return "1", nil 232 } 233 234 return "0", nil 235 236 case time.Time: 237 s, err := formatDate("%Y-%m-%d %H:%i:%s", val) 238 239 if err != nil { 240 return nil, err 241 } 242 243 s += fractionOfSecString(val) 244 245 return hexForString(s), nil 246 247 case []byte: 248 return hexForString(string(val)), nil 249 250 case types.GeometryValue: 251 return hexForString(string(val.Serialize())), nil 252 253 default: 254 return nil, sql.ErrInvalidArgumentDetails.New("hex", fmt.Sprint(arg)) 255 } 256 } 257 258 // WithChildren implements the sql.Expression interface 259 func (h *Hex) WithChildren(children ...sql.Expression) (sql.Expression, error) { 260 if len(children) != 1 { 261 return nil, sql.ErrInvalidChildrenNumber.New(h, len(children), 1) 262 } 263 return NewHex(children[0]), nil 264 } 265 266 func hexChar(b byte) byte { 267 if b > 9 { 268 return b - 10 + byte('A') 269 } 270 271 return b + byte('0') 272 } 273 274 // MySQL expects the 64 bit 2s complement representation for negative integer values. Typical methods for converting a 275 // number to a string don't handle negative integer values in this way (strconv.FormatInt and fmt.Sprintf for example). 276 func hexForNegativeInt64(n int64) string { 277 // get a pointer to the int64s memory 278 mem := (*[8]byte)(unsafe.Pointer(&n)) 279 // make a copy of the data that I can manipulate 280 bytes := *mem 281 // reverse the order for printing 282 for i := 0; i < 4; i++ { 283 bytes[i], bytes[7-i] = bytes[7-i], bytes[i] 284 } 285 // print the hex encoded bytes 286 return fmt.Sprintf("%X", bytes) 287 } 288 289 func hexForFloat(f float64) (string, error) { 290 if f < 0 { 291 f -= 0.5 292 n := int64(f) 293 return hexForNegativeInt64(n), nil 294 } 295 296 f += 0.5 297 n := uint64(f) 298 return fmt.Sprintf("%X", n), nil 299 } 300 301 func hexForString(val string) string { 302 buf := make([]byte, 0, 2*len(val)) 303 // Do not change this to range, as range iterates over runes and not bytes 304 for i := 0; i < len(val); i++ { 305 c := val[i] 306 high := c / 16 307 low := c % 16 308 309 buf = append(buf, hexChar(high)) 310 buf = append(buf, hexChar(low)) 311 } 312 return string(buf) 313 } 314 315 // Unhex implements the sql function "unhex" which returns the integer representation of a hexadecimal string 316 type Unhex struct { 317 *UnaryFunc 318 } 319 320 var _ sql.FunctionExpression = (*Unhex)(nil) 321 var _ sql.CollationCoercible = (*Unhex)(nil) 322 323 func NewUnhex(arg sql.Expression) sql.Expression { 324 return &Unhex{NewUnaryFunc(arg, "UNHEX", types.LongBlob)} 325 } 326 327 // Description implements sql.FunctionExpression 328 func (h *Unhex) Description() string { 329 return "returns a string containing hex representation of a number." 330 } 331 332 // CollationCoercibility implements the interface sql.CollationCoercible. 333 func (*Unhex) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 334 return sql.Collation_binary, 4 335 } 336 337 // Eval implements the sql.Expression interface 338 func (h *Unhex) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 339 arg, err := h.EvalChild(ctx, row) 340 if err != nil { 341 return nil, err 342 } 343 344 if arg == nil { 345 return nil, nil 346 } 347 348 val, _, err := types.LongText.Convert(arg) 349 350 if err != nil { 351 return nil, err 352 } 353 354 s := val.(string) 355 if len(s)%2 != 0 { 356 s = "0" + s 357 } 358 359 s = strings.ToUpper(s) 360 for _, c := range s { 361 if c < '0' || c > '9' && c < 'A' || c > 'F' { 362 return nil, nil 363 } 364 } 365 366 res, err := hex.DecodeString(s) 367 368 if err != nil { 369 return nil, err 370 } 371 372 return res, nil 373 } 374 375 // WithChildren implements the sql.Expression interface 376 func (h *Unhex) WithChildren(children ...sql.Expression) (sql.Expression, error) { 377 if len(children) != 1 { 378 return nil, sql.ErrInvalidChildrenNumber.New(h, len(children), 1) 379 } 380 return NewUnhex(children[0]), nil 381 } 382 383 // MySQL expects the 64 bit 2s complement representation for negative integer values. Typical methods for converting a 384 // number to a string don't handle negative integer values in this way (strconv.FormatInt and fmt.Sprintf for example). 385 func binForNegativeInt64(n int64) string { 386 // get a pointer to the int64s memory 387 mem := (*[8]byte)(unsafe.Pointer(&n)) 388 // make a copy of the data that I can manipulate 389 bytes := *mem 390 391 s := "" 392 for i := 7; i >= 0; i-- { 393 s += strconv.FormatInt(int64(bytes[i]), 2) 394 } 395 396 return s 397 } 398 399 // Bin implements the sql function "bin" which returns the binary representation of a number 400 type Bin struct { 401 *UnaryFunc 402 } 403 404 var _ sql.FunctionExpression = (*Bin)(nil) 405 var _ sql.CollationCoercible = (*Bin)(nil) 406 407 func NewBin(arg sql.Expression) sql.Expression { 408 return &Bin{NewUnaryFunc(arg, "BIN", types.Text)} 409 } 410 411 // FunctionName implements sql.FunctionExpression 412 func (b *Bin) FunctionName() string { 413 return "bin" 414 } 415 416 // Description implements sql.FunctionExpression 417 func (b *Bin) Description() string { 418 return "returns the binary representation of a number." 419 } 420 421 // CollationCoercibility implements the interface sql.CollationCoercible. 422 func (*Bin) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 423 return ctx.GetCollation(), 4 424 } 425 426 // Eval implements the sql.Expression interface 427 func (h *Bin) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 428 arg, err := h.EvalChild(ctx, row) 429 if err != nil { 430 return nil, err 431 } 432 433 if arg == nil { 434 return nil, nil 435 } 436 437 switch val := arg.(type) { 438 case time.Time: 439 return strconv.FormatUint(uint64(val.Year()), 2), nil 440 case uint64: 441 return strconv.FormatUint(val, 2), nil 442 443 default: 444 n, err := h.convertToInt64(arg) 445 446 if err != nil { 447 return "0", nil 448 } 449 450 if n < 0 { 451 return binForNegativeInt64(n), nil 452 } else { 453 return strconv.FormatInt(n, 2), nil 454 } 455 } 456 } 457 458 // WithChildren implements the sql.Expression interface 459 func (h *Bin) WithChildren(children ...sql.Expression) (sql.Expression, error) { 460 if len(children) != 1 { 461 return nil, sql.ErrInvalidChildrenNumber.New(h, len(children), 1) 462 } 463 return NewBin(children[0]), nil 464 } 465 466 // convertToInt64 handles the conversion from the given interface to an Int64. This mirrors the original behavior of how 467 // sql.Int64 handled conversions, which matches the expected behavior of this function. sql.Int64 has been fixed, 468 // and the fixes cause incorrect behavior for this function (as they use different rules), therefore this is simply to 469 // restore the original behavior specifically for this function. 470 func (h *Bin) convertToInt64(v interface{}) (int64, error) { 471 switch v := v.(type) { 472 case int: 473 return int64(v), nil 474 case int8: 475 return int64(v), nil 476 case int16: 477 return int64(v), nil 478 case int32: 479 return int64(v), nil 480 case int64: 481 return v, nil 482 case uint: 483 return int64(v), nil 484 case uint8: 485 return int64(v), nil 486 case uint16: 487 return int64(v), nil 488 case uint32: 489 return int64(v), nil 490 case uint64: 491 if v > math.MaxInt64 { 492 return math.MaxInt64, nil 493 } 494 return int64(v), nil 495 case float32: 496 if v >= float32(math.MaxInt64) { 497 return math.MaxInt64, nil 498 } else if v <= float32(math.MinInt64) { 499 return math.MinInt64, nil 500 } 501 return int64(v), nil 502 case float64: 503 if v >= float64(math.MaxInt64) { 504 return math.MaxInt64, nil 505 } else if v <= float64(math.MinInt64) { 506 return math.MinInt64, nil 507 } 508 return int64(v), nil 509 case decimal.Decimal: 510 if v.GreaterThan(decimal.NewFromInt(math.MaxInt64)) { 511 return math.MaxInt64, nil 512 } else if v.LessThan(decimal.NewFromInt(math.MinInt64)) { 513 return math.MinInt64, nil 514 } 515 return v.IntPart(), nil 516 case []byte: 517 i, err := strconv.ParseInt(hex.EncodeToString(v), 16, 64) 518 if err != nil { 519 return 0, sql.ErrInvalidValue.New(v, types.Int64.String()) 520 } 521 return i, nil 522 case string: 523 // Parse first an integer, which allows for more values than float64 524 i, err := strconv.ParseInt(v, 10, 64) 525 if err == nil { 526 return i, nil 527 } 528 // If that fails, try as a float and truncate it to integral 529 f, err := strconv.ParseFloat(v, 64) 530 if err != nil { 531 return 0, sql.ErrInvalidValue.New(v, types.Int64.String()) 532 } 533 return int64(f), nil 534 case bool: 535 if v { 536 return 1, nil 537 } 538 return 0, nil 539 case nil: 540 return 0, nil 541 default: 542 return 0, sql.ErrInvalidValueType.New(v, types.Int64.String()) 543 } 544 } 545 546 // Bitlength implements the sql function "bit_length" which returns the data length of the argument in bits 547 type Bitlength struct { 548 *UnaryFunc 549 } 550 551 var _ sql.FunctionExpression = (*Bitlength)(nil) 552 var _ sql.CollationCoercible = (*Bitlength)(nil) 553 554 func NewBitlength(arg sql.Expression) sql.Expression { 555 return &Bitlength{NewUnaryFunc(arg, "BIT_LENGTH", types.Int32)} 556 } 557 558 // FunctionName implements sql.FunctionExpression 559 func (b *Bitlength) FunctionName() string { 560 return "bit_length" 561 } 562 563 // Description implements sql.FunctionExpression 564 func (b *Bitlength) Description() string { 565 return "returns the data length of the argument in bits." 566 } 567 568 // CollationCoercibility implements the interface sql.CollationCoercible. 569 func (*Bitlength) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 570 return sql.Collation_binary, 5 571 } 572 573 // Eval implements the sql.Expression interface 574 func (h *Bitlength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 575 arg, err := h.EvalChild(ctx, row) 576 if err != nil { 577 return nil, err 578 } 579 580 if arg == nil { 581 return nil, nil 582 } 583 584 switch val := arg.(type) { 585 case uint8, int8, bool: 586 return 8, nil 587 case uint16, int16: 588 return 16, nil 589 case int, uint, uint32, int32, float32: 590 return 32, nil 591 case uint64, int64, float64: 592 return 64, nil 593 case string: 594 return 8 * len([]byte(val)), nil 595 case time.Time: 596 return 128, nil 597 } 598 599 return nil, sql.ErrInvalidArgumentDetails.New("bit_length", fmt.Sprint(arg)) 600 } 601 602 // WithChildren implements the sql.Expression interface 603 func (h *Bitlength) WithChildren(children ...sql.Expression) (sql.Expression, error) { 604 if len(children) != 1 { 605 return nil, sql.ErrInvalidChildrenNumber.New(h, len(children), 1) 606 } 607 return NewBitlength(children[0]), nil 608 }