github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/string.go (about)

     1  // Copyright 2020-2024 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"encoding/hex"
    19  	"fmt"
    20  	"math"
    21  	"strconv"
    22  	"strings"
    23  	"time"
    24  	"unsafe"
    25  
    26  	"github.com/shopspring/decimal"
    27  
    28  	"github.com/dolthub/go-mysql-server/sql"
    29  	"github.com/dolthub/go-mysql-server/sql/encodings"
    30  	"github.com/dolthub/go-mysql-server/sql/types"
    31  )
    32  
    33  // Ascii implements the sql function "ascii" which returns the numeric value of the leftmost character
    34  type Ascii struct {
    35  	*UnaryFunc
    36  }
    37  
    38  var _ sql.FunctionExpression = (*Ascii)(nil)
    39  var _ sql.CollationCoercible = (*Ascii)(nil)
    40  
    41  func NewAscii(arg sql.Expression) sql.Expression {
    42  	return &Ascii{NewUnaryFunc(arg, "ASCII", types.Uint8)}
    43  }
    44  
    45  // Description implements sql.FunctionExpression
    46  func (a *Ascii) Description() string {
    47  	return "returns the numeric value of the leftmost character."
    48  }
    49  
    50  // CollationCoercibility implements the interface sql.CollationCoercible.
    51  func (*Ascii) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    52  	return sql.Collation_binary, 5
    53  }
    54  
    55  // Eval implements the sql.Expression interface
    56  func (a *Ascii) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    57  	val, err := a.EvalChild(ctx, row)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	if val == nil {
    63  		return nil, nil
    64  	}
    65  
    66  	str, _, err := types.Text.Convert(val)
    67  
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  
    72  	s := str.(string)
    73  	if len(s) == 0 {
    74  		return uint8(0), nil
    75  	}
    76  
    77  	return s[0], nil
    78  }
    79  
    80  // WithChildren implements the sql.Expression interface
    81  func (a *Ascii) WithChildren(children ...sql.Expression) (sql.Expression, error) {
    82  	if len(children) != 1 {
    83  		return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 1)
    84  	}
    85  	return NewAscii(children[0]), nil
    86  }
    87  
    88  // Ord implements the sql function "ord" which returns the numeric value of the leftmost character
    89  type Ord struct {
    90  	*UnaryFunc
    91  }
    92  
    93  var _ sql.FunctionExpression = (*Ord)(nil)
    94  var _ sql.CollationCoercible = (*Ord)(nil)
    95  
    96  func NewOrd(arg sql.Expression) sql.Expression {
    97  	return &Ord{NewUnaryFunc(arg, "ORD", types.Int64)}
    98  }
    99  
   100  // Description implements sql.FunctionExpression
   101  func (o *Ord) Description() string {
   102  	return "return character code for leftmost character of the argument."
   103  }
   104  
   105  // CollationCoercibility implements the interface sql.CollationCoercible.
   106  func (o *Ord) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   107  	return sql.Collation_binary, 5
   108  }
   109  
   110  // Eval implements the sql.Expression interface
   111  func (o *Ord) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   112  	val, err := o.EvalChild(ctx, row)
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  
   117  	if val == nil {
   118  		return nil, nil
   119  	}
   120  
   121  	str, _, err := types.Text.Convert(val)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  	s := str.(string)
   126  	if len(s) == 0 {
   127  		return int64(0), nil
   128  	}
   129  
   130  	// get the leftmost unicode code point as bytes
   131  	b := []byte(string([]rune(s)[0]))
   132  
   133  	// convert into ord
   134  	var res int64
   135  	for i, c := range b {
   136  		res += int64(c) << (8 * (len(b) - 1 - i))
   137  	}
   138  
   139  	return res, nil
   140  }
   141  
   142  // WithChildren implements the sql.Expression interface
   143  func (o *Ord) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   144  	if len(children) != 1 {
   145  		return nil, sql.ErrInvalidChildrenNumber.New(o, len(children), 1)
   146  	}
   147  	return NewOrd(children[0]), nil
   148  }
   149  
   150  // Hex implements the sql function "hex" which returns the hexadecimal representation of the string or numeric value
   151  type Hex struct {
   152  	*UnaryFunc
   153  }
   154  
   155  var _ sql.FunctionExpression = (*Hex)(nil)
   156  var _ sql.CollationCoercible = (*Hex)(nil)
   157  
   158  func NewHex(arg sql.Expression) sql.Expression {
   159  	// Although this may seem convoluted, the Collation_Default is NOT guaranteed to be the character set's default
   160  	// collation. This ensures that you're getting the character set's default collation, and also works in the event
   161  	// that the Collation_Default is ever changed.
   162  	retType := types.CreateLongText(sql.Collation_Default.CharacterSet().DefaultCollation())
   163  	return &Hex{NewUnaryFunc(arg, "HEX", retType)}
   164  }
   165  
   166  // Description implements sql.FunctionExpression
   167  func (h *Hex) Description() string {
   168  	return "returns the hexadecimal representation of the string or numeric value."
   169  }
   170  
   171  // CollationCoercibility implements the interface sql.CollationCoercible.
   172  func (*Hex) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   173  	return ctx.GetCollation(), 4
   174  }
   175  
   176  // Eval implements the sql.Expression interface
   177  func (h *Hex) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   178  	arg, err := h.EvalChild(ctx, row)
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  
   183  	if arg == nil {
   184  		return nil, nil
   185  	}
   186  
   187  	switch val := arg.(type) {
   188  	case string:
   189  		childType := h.Child.Type()
   190  		if types.IsTextOnly(childType) {
   191  			// For string types we need to re-encode the internal string so that we get the correct hex output
   192  			encoder := childType.(sql.StringType).Collation().CharacterSet().Encoder()
   193  			encodedBytes, ok := encoder.Encode(encodings.StringToBytes(val))
   194  			if !ok {
   195  				return nil, fmt.Errorf("unable to re-encode string for HEX function")
   196  			}
   197  			return hexForString(encodings.BytesToString(encodedBytes)), nil
   198  		} else {
   199  			return hexForString(val), nil
   200  		}
   201  
   202  	case uint8, uint16, uint32, uint, int, int8, int16, int32, int64:
   203  		n, _, err := types.Int64.Convert(arg)
   204  
   205  		if err != nil {
   206  			return nil, err
   207  		}
   208  
   209  		a := n.(int64)
   210  		if a < 0 {
   211  			return hexForNegativeInt64(a), nil
   212  		} else {
   213  			return fmt.Sprintf("%X", a), nil
   214  		}
   215  
   216  	case uint64:
   217  		return fmt.Sprintf("%X", val), nil
   218  
   219  	case float32:
   220  		return hexForFloat(float64(val))
   221  
   222  	case float64:
   223  		return hexForFloat(val)
   224  
   225  	case decimal.Decimal:
   226  		f, _ := val.Float64()
   227  		return hexForFloat(f)
   228  
   229  	case bool:
   230  		if val {
   231  			return "1", nil
   232  		}
   233  
   234  		return "0", nil
   235  
   236  	case time.Time:
   237  		s, err := formatDate("%Y-%m-%d %H:%i:%s", val)
   238  
   239  		if err != nil {
   240  			return nil, err
   241  		}
   242  
   243  		s += fractionOfSecString(val)
   244  
   245  		return hexForString(s), nil
   246  
   247  	case []byte:
   248  		return hexForString(string(val)), nil
   249  
   250  	case types.GeometryValue:
   251  		return hexForString(string(val.Serialize())), nil
   252  
   253  	default:
   254  		return nil, sql.ErrInvalidArgumentDetails.New("hex", fmt.Sprint(arg))
   255  	}
   256  }
   257  
   258  // WithChildren implements the sql.Expression interface
   259  func (h *Hex) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   260  	if len(children) != 1 {
   261  		return nil, sql.ErrInvalidChildrenNumber.New(h, len(children), 1)
   262  	}
   263  	return NewHex(children[0]), nil
   264  }
   265  
   266  func hexChar(b byte) byte {
   267  	if b > 9 {
   268  		return b - 10 + byte('A')
   269  	}
   270  
   271  	return b + byte('0')
   272  }
   273  
   274  // MySQL expects the 64 bit 2s complement representation for negative integer values. Typical methods for converting a
   275  // number to a string don't handle negative integer values in this way (strconv.FormatInt and fmt.Sprintf for example).
   276  func hexForNegativeInt64(n int64) string {
   277  	// get a pointer to the int64s memory
   278  	mem := (*[8]byte)(unsafe.Pointer(&n))
   279  	// make a copy of the data that I can manipulate
   280  	bytes := *mem
   281  	// reverse the order for printing
   282  	for i := 0; i < 4; i++ {
   283  		bytes[i], bytes[7-i] = bytes[7-i], bytes[i]
   284  	}
   285  	// print the hex encoded bytes
   286  	return fmt.Sprintf("%X", bytes)
   287  }
   288  
   289  func hexForFloat(f float64) (string, error) {
   290  	if f < 0 {
   291  		f -= 0.5
   292  		n := int64(f)
   293  		return hexForNegativeInt64(n), nil
   294  	}
   295  
   296  	f += 0.5
   297  	n := uint64(f)
   298  	return fmt.Sprintf("%X", n), nil
   299  }
   300  
   301  func hexForString(val string) string {
   302  	buf := make([]byte, 0, 2*len(val))
   303  	// Do not change this to range, as range iterates over runes and not bytes
   304  	for i := 0; i < len(val); i++ {
   305  		c := val[i]
   306  		high := c / 16
   307  		low := c % 16
   308  
   309  		buf = append(buf, hexChar(high))
   310  		buf = append(buf, hexChar(low))
   311  	}
   312  	return string(buf)
   313  }
   314  
   315  // Unhex implements the sql function "unhex" which returns the integer representation of a hexadecimal string
   316  type Unhex struct {
   317  	*UnaryFunc
   318  }
   319  
   320  var _ sql.FunctionExpression = (*Unhex)(nil)
   321  var _ sql.CollationCoercible = (*Unhex)(nil)
   322  
   323  func NewUnhex(arg sql.Expression) sql.Expression {
   324  	return &Unhex{NewUnaryFunc(arg, "UNHEX", types.LongBlob)}
   325  }
   326  
   327  // Description implements sql.FunctionExpression
   328  func (h *Unhex) Description() string {
   329  	return "returns a string containing hex representation of a number."
   330  }
   331  
   332  // CollationCoercibility implements the interface sql.CollationCoercible.
   333  func (*Unhex) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   334  	return sql.Collation_binary, 4
   335  }
   336  
   337  // Eval implements the sql.Expression interface
   338  func (h *Unhex) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   339  	arg, err := h.EvalChild(ctx, row)
   340  	if err != nil {
   341  		return nil, err
   342  	}
   343  
   344  	if arg == nil {
   345  		return nil, nil
   346  	}
   347  
   348  	val, _, err := types.LongText.Convert(arg)
   349  
   350  	if err != nil {
   351  		return nil, err
   352  	}
   353  
   354  	s := val.(string)
   355  	if len(s)%2 != 0 {
   356  		s = "0" + s
   357  	}
   358  
   359  	s = strings.ToUpper(s)
   360  	for _, c := range s {
   361  		if c < '0' || c > '9' && c < 'A' || c > 'F' {
   362  			return nil, nil
   363  		}
   364  	}
   365  
   366  	res, err := hex.DecodeString(s)
   367  
   368  	if err != nil {
   369  		return nil, err
   370  	}
   371  
   372  	return res, nil
   373  }
   374  
   375  // WithChildren implements the sql.Expression interface
   376  func (h *Unhex) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   377  	if len(children) != 1 {
   378  		return nil, sql.ErrInvalidChildrenNumber.New(h, len(children), 1)
   379  	}
   380  	return NewUnhex(children[0]), nil
   381  }
   382  
   383  // MySQL expects the 64 bit 2s complement representation for negative integer values. Typical methods for converting a
   384  // number to a string don't handle negative integer values in this way (strconv.FormatInt and fmt.Sprintf for example).
   385  func binForNegativeInt64(n int64) string {
   386  	// get a pointer to the int64s memory
   387  	mem := (*[8]byte)(unsafe.Pointer(&n))
   388  	// make a copy of the data that I can manipulate
   389  	bytes := *mem
   390  
   391  	s := ""
   392  	for i := 7; i >= 0; i-- {
   393  		s += strconv.FormatInt(int64(bytes[i]), 2)
   394  	}
   395  
   396  	return s
   397  }
   398  
   399  // Bin implements the sql function "bin" which returns the binary representation of a number
   400  type Bin struct {
   401  	*UnaryFunc
   402  }
   403  
   404  var _ sql.FunctionExpression = (*Bin)(nil)
   405  var _ sql.CollationCoercible = (*Bin)(nil)
   406  
   407  func NewBin(arg sql.Expression) sql.Expression {
   408  	return &Bin{NewUnaryFunc(arg, "BIN", types.Text)}
   409  }
   410  
   411  // FunctionName implements sql.FunctionExpression
   412  func (b *Bin) FunctionName() string {
   413  	return "bin"
   414  }
   415  
   416  // Description implements sql.FunctionExpression
   417  func (b *Bin) Description() string {
   418  	return "returns the binary representation of a number."
   419  }
   420  
   421  // CollationCoercibility implements the interface sql.CollationCoercible.
   422  func (*Bin) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   423  	return ctx.GetCollation(), 4
   424  }
   425  
   426  // Eval implements the sql.Expression interface
   427  func (h *Bin) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   428  	arg, err := h.EvalChild(ctx, row)
   429  	if err != nil {
   430  		return nil, err
   431  	}
   432  
   433  	if arg == nil {
   434  		return nil, nil
   435  	}
   436  
   437  	switch val := arg.(type) {
   438  	case time.Time:
   439  		return strconv.FormatUint(uint64(val.Year()), 2), nil
   440  	case uint64:
   441  		return strconv.FormatUint(val, 2), nil
   442  
   443  	default:
   444  		n, err := h.convertToInt64(arg)
   445  
   446  		if err != nil {
   447  			return "0", nil
   448  		}
   449  
   450  		if n < 0 {
   451  			return binForNegativeInt64(n), nil
   452  		} else {
   453  			return strconv.FormatInt(n, 2), nil
   454  		}
   455  	}
   456  }
   457  
   458  // WithChildren implements the sql.Expression interface
   459  func (h *Bin) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   460  	if len(children) != 1 {
   461  		return nil, sql.ErrInvalidChildrenNumber.New(h, len(children), 1)
   462  	}
   463  	return NewBin(children[0]), nil
   464  }
   465  
   466  // convertToInt64 handles the conversion from the given interface to an Int64. This mirrors the original behavior of how
   467  // sql.Int64 handled conversions, which matches the expected behavior of this function. sql.Int64 has been fixed,
   468  // and the fixes cause incorrect behavior for this function (as they use different rules), therefore this is simply to
   469  // restore the original behavior specifically for this function.
   470  func (h *Bin) convertToInt64(v interface{}) (int64, error) {
   471  	switch v := v.(type) {
   472  	case int:
   473  		return int64(v), nil
   474  	case int8:
   475  		return int64(v), nil
   476  	case int16:
   477  		return int64(v), nil
   478  	case int32:
   479  		return int64(v), nil
   480  	case int64:
   481  		return v, nil
   482  	case uint:
   483  		return int64(v), nil
   484  	case uint8:
   485  		return int64(v), nil
   486  	case uint16:
   487  		return int64(v), nil
   488  	case uint32:
   489  		return int64(v), nil
   490  	case uint64:
   491  		if v > math.MaxInt64 {
   492  			return math.MaxInt64, nil
   493  		}
   494  		return int64(v), nil
   495  	case float32:
   496  		if v >= float32(math.MaxInt64) {
   497  			return math.MaxInt64, nil
   498  		} else if v <= float32(math.MinInt64) {
   499  			return math.MinInt64, nil
   500  		}
   501  		return int64(v), nil
   502  	case float64:
   503  		if v >= float64(math.MaxInt64) {
   504  			return math.MaxInt64, nil
   505  		} else if v <= float64(math.MinInt64) {
   506  			return math.MinInt64, nil
   507  		}
   508  		return int64(v), nil
   509  	case decimal.Decimal:
   510  		if v.GreaterThan(decimal.NewFromInt(math.MaxInt64)) {
   511  			return math.MaxInt64, nil
   512  		} else if v.LessThan(decimal.NewFromInt(math.MinInt64)) {
   513  			return math.MinInt64, nil
   514  		}
   515  		return v.IntPart(), nil
   516  	case []byte:
   517  		i, err := strconv.ParseInt(hex.EncodeToString(v), 16, 64)
   518  		if err != nil {
   519  			return 0, sql.ErrInvalidValue.New(v, types.Int64.String())
   520  		}
   521  		return i, nil
   522  	case string:
   523  		// Parse first an integer, which allows for more values than float64
   524  		i, err := strconv.ParseInt(v, 10, 64)
   525  		if err == nil {
   526  			return i, nil
   527  		}
   528  		// If that fails, try as a float and truncate it to integral
   529  		f, err := strconv.ParseFloat(v, 64)
   530  		if err != nil {
   531  			return 0, sql.ErrInvalidValue.New(v, types.Int64.String())
   532  		}
   533  		return int64(f), nil
   534  	case bool:
   535  		if v {
   536  			return 1, nil
   537  		}
   538  		return 0, nil
   539  	case nil:
   540  		return 0, nil
   541  	default:
   542  		return 0, sql.ErrInvalidValueType.New(v, types.Int64.String())
   543  	}
   544  }
   545  
   546  // Bitlength implements the sql function "bit_length" which returns the data length of the argument in bits
   547  type Bitlength struct {
   548  	*UnaryFunc
   549  }
   550  
   551  var _ sql.FunctionExpression = (*Bitlength)(nil)
   552  var _ sql.CollationCoercible = (*Bitlength)(nil)
   553  
   554  func NewBitlength(arg sql.Expression) sql.Expression {
   555  	return &Bitlength{NewUnaryFunc(arg, "BIT_LENGTH", types.Int32)}
   556  }
   557  
   558  // FunctionName implements sql.FunctionExpression
   559  func (b *Bitlength) FunctionName() string {
   560  	return "bit_length"
   561  }
   562  
   563  // Description implements sql.FunctionExpression
   564  func (b *Bitlength) Description() string {
   565  	return "returns the data length of the argument in bits."
   566  }
   567  
   568  // CollationCoercibility implements the interface sql.CollationCoercible.
   569  func (*Bitlength) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   570  	return sql.Collation_binary, 5
   571  }
   572  
   573  // Eval implements the sql.Expression interface
   574  func (h *Bitlength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   575  	arg, err := h.EvalChild(ctx, row)
   576  	if err != nil {
   577  		return nil, err
   578  	}
   579  
   580  	if arg == nil {
   581  		return nil, nil
   582  	}
   583  
   584  	switch val := arg.(type) {
   585  	case uint8, int8, bool:
   586  		return 8, nil
   587  	case uint16, int16:
   588  		return 16, nil
   589  	case int, uint, uint32, int32, float32:
   590  		return 32, nil
   591  	case uint64, int64, float64:
   592  		return 64, nil
   593  	case string:
   594  		return 8 * len([]byte(val)), nil
   595  	case time.Time:
   596  		return 128, nil
   597  	}
   598  
   599  	return nil, sql.ErrInvalidArgumentDetails.New("bit_length", fmt.Sprint(arg))
   600  }
   601  
   602  // WithChildren implements the sql.Expression interface
   603  func (h *Bitlength) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   604  	if len(children) != 1 {
   605  		return nil, sql.ErrInvalidChildrenNumber.New(h, len(children), 1)
   606  	}
   607  	return NewBitlength(children[0]), nil
   608  }