github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/substring.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"fmt"
    19  	"reflect"
    20  	"strings"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/types"
    24  )
    25  
    26  // Substring is a function to return a part of a string.
    27  // This function behaves as the homonym MySQL function.
    28  // Since Go strings are UTF8, this function does not return a direct sub
    29  // string str[start:start+length], instead returns the substring of rune
    30  // s. That is, "á"[0:1] does not return a partial unicode glyph, but "á"
    31  // itself.
    32  type Substring struct {
    33  	Str   sql.Expression
    34  	Start sql.Expression
    35  	Len   sql.Expression
    36  }
    37  
    38  var _ sql.FunctionExpression = (*Substring)(nil)
    39  var _ sql.CollationCoercible = (*Substring)(nil)
    40  
    41  // NewSubstring creates a new substring UDF.
    42  func NewSubstring(args ...sql.Expression) (sql.Expression, error) {
    43  	var str, start, ln sql.Expression
    44  	switch len(args) {
    45  	case 2:
    46  		str = args[0]
    47  		start = args[1]
    48  		ln = nil
    49  	case 3:
    50  		str = args[0]
    51  		start = args[1]
    52  		ln = args[2]
    53  	default:
    54  		return nil, sql.ErrInvalidArgumentNumber.New("SUBSTRING", "2 or 3", len(args))
    55  	}
    56  	return &Substring{str, start, ln}, nil
    57  }
    58  
    59  // FunctionName implements sql.FunctionExpression
    60  func (s *Substring) FunctionName() string {
    61  	return "substring"
    62  }
    63  
    64  // Description implements sql.FunctionExpression
    65  func (s *Substring) Description() string {
    66  	return "returns a substring from the provided string starting at pos with a length of len characters. If no len is provided, all characters from pos until the end will be taken."
    67  }
    68  
    69  // Children implements the Expression interface.
    70  func (s *Substring) Children() []sql.Expression {
    71  	if s.Len == nil {
    72  		return []sql.Expression{s.Str, s.Start}
    73  	}
    74  	return []sql.Expression{s.Str, s.Start, s.Len}
    75  }
    76  
    77  // Eval implements the Expression interface.
    78  func (s *Substring) Eval(
    79  	ctx *sql.Context,
    80  	row sql.Row,
    81  ) (interface{}, error) {
    82  	str, err := s.Str.Eval(ctx, row)
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  
    87  	var text []rune
    88  	switch str := str.(type) {
    89  	case string:
    90  		text = []rune(str)
    91  	case []byte:
    92  		text = []rune(string(str))
    93  	case nil:
    94  		return nil, nil
    95  	default:
    96  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(str).String())
    97  	}
    98  
    99  	start, err := s.Start.Eval(ctx, row)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	if start == nil {
   105  		return nil, nil
   106  	}
   107  
   108  	start, _, err = types.Int64.Convert(start)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  
   113  	var length int64
   114  	runeCount := int64(len(text))
   115  	if s.Len != nil {
   116  		len, err := s.Len.Eval(ctx, row)
   117  		if err != nil {
   118  			return nil, err
   119  		}
   120  
   121  		if len == nil {
   122  			return nil, nil
   123  		}
   124  
   125  		len, _, err = types.Int64.Convert(len)
   126  		if err != nil {
   127  			return nil, err
   128  		}
   129  
   130  		length = len.(int64)
   131  	} else {
   132  		length = runeCount
   133  	}
   134  
   135  	var startIdx int64
   136  	if start := start.(int64); start < 0 {
   137  		startIdx = runeCount + start
   138  	} else {
   139  		startIdx = start - 1
   140  	}
   141  
   142  	if startIdx < 0 || startIdx >= runeCount || length <= 0 {
   143  		return "", nil
   144  	}
   145  
   146  	if startIdx+length > runeCount {
   147  		length = int64(runeCount) - startIdx
   148  	}
   149  
   150  	return string(text[startIdx : startIdx+length]), nil
   151  }
   152  
   153  // IsNullable implements the Expression interface.
   154  func (s *Substring) IsNullable() bool {
   155  	return s.Str.IsNullable() || s.Start.IsNullable() || (s.Len != nil && s.Len.IsNullable())
   156  }
   157  
   158  func (s *Substring) String() string {
   159  	if s.Len == nil {
   160  		return fmt.Sprintf("SUBSTRING(%s, %s)", s.Str, s.Start)
   161  	}
   162  	return fmt.Sprintf("SUBSTRING(%s, %s, %s)", s.Str, s.Start, s.Len)
   163  }
   164  
   165  // Resolved implements the Expression interface.
   166  func (s *Substring) Resolved() bool {
   167  	return s.Start.Resolved() && s.Str.Resolved() && (s.Len == nil || s.Len.Resolved())
   168  }
   169  
   170  // Type implements the Expression interface.
   171  func (s *Substring) Type() sql.Type { return s.Str.Type() }
   172  
   173  // CollationCoercibility implements the interface sql.CollationCoercible.
   174  func (s *Substring) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   175  	return sql.GetCoercibility(ctx, s.Str)
   176  }
   177  
   178  // WithChildren implements the Expression interface.
   179  func (*Substring) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   180  	return NewSubstring(children...)
   181  }
   182  
   183  // SubstringIndex returns the substring from string str before count occurrences of the delimiter delim.
   184  // If count is positive, everything to the left of the final delimiter (counting from the left) is returned.
   185  // If count is negative, everything to the right of the final delimiter (counting from the right) is returned.
   186  // SUBSTRING_INDEX() performs a case-sensitive match when searching for delim.
   187  type SubstringIndex struct {
   188  	str   sql.Expression
   189  	delim sql.Expression
   190  	count sql.Expression
   191  }
   192  
   193  var _ sql.FunctionExpression = (*SubstringIndex)(nil)
   194  var _ sql.CollationCoercible = (*SubstringIndex)(nil)
   195  
   196  // NewSubstringIndex creates a new SubstringIndex UDF.
   197  func NewSubstringIndex(str, delim, count sql.Expression) sql.Expression {
   198  	return &SubstringIndex{str, delim, count}
   199  }
   200  
   201  // FunctionName implements sql.FunctionExpression
   202  func (s *SubstringIndex) FunctionName() string {
   203  	return "substring_index"
   204  }
   205  
   206  // Description implements sql.FunctionExpression
   207  func (s *SubstringIndex) Description() string {
   208  	return "returns a substring after count appearances of delim. If count is negative, counts from the right side of the string."
   209  }
   210  
   211  // Children implements the Expression interface.
   212  func (s *SubstringIndex) Children() []sql.Expression {
   213  	return []sql.Expression{s.str, s.delim, s.count}
   214  }
   215  
   216  // Eval implements the Expression interface.
   217  func (s *SubstringIndex) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   218  	ex, err := s.str.Eval(ctx, row)
   219  	if ex == nil || err != nil {
   220  		return nil, err
   221  	}
   222  	ex, _, err = types.LongText.Convert(ex)
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  	str, ok := ex.(string)
   227  	if !ok {
   228  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(ex).String())
   229  	}
   230  
   231  	ex, err = s.delim.Eval(ctx, row)
   232  	if ex == nil || err != nil {
   233  		return nil, err
   234  	}
   235  	ex, _, err = types.LongText.Convert(ex)
   236  	if err != nil {
   237  		return nil, err
   238  	}
   239  	delim, ok := ex.(string)
   240  	if !ok {
   241  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(ex).String())
   242  	}
   243  
   244  	ex, err = s.count.Eval(ctx, row)
   245  	if ex == nil || err != nil {
   246  		return nil, err
   247  	}
   248  	ex, _, err = types.Int64.Convert(ex)
   249  	if err != nil {
   250  		return nil, err
   251  	}
   252  	count, ok := ex.(int64)
   253  	if !ok {
   254  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(ex).String())
   255  	}
   256  
   257  	// Implementation taken from pingcap/tidb
   258  	// https://github.com/pingcap/tidb/blob/37c128b64f3ad2f08d52bc767b6e3320ecf429d8/expression/builtin_string.go#L1229
   259  	strs := strings.Split(str, delim)
   260  	start, end := int64(0), int64(len(strs))
   261  	if count > 0 {
   262  		// If count is positive, everything to the left of the final delimiter (counting from the left) is returned.
   263  		if count < end {
   264  			end = count
   265  		}
   266  	} else {
   267  		// If count is negative, everything to the right of the final delimiter (counting from the right) is returned.
   268  		count = -count
   269  		if count < 0 {
   270  			// -count overflows max int64, returns an empty string.
   271  			return "", nil
   272  		}
   273  
   274  		if count < end {
   275  			start = end - count
   276  		}
   277  	}
   278  
   279  	return strings.Join(strs[start:end], delim), nil
   280  }
   281  
   282  // IsNullable implements the Expression interface.
   283  func (s *SubstringIndex) IsNullable() bool {
   284  	return s.str.IsNullable() || s.delim.IsNullable() || s.count.IsNullable()
   285  }
   286  
   287  func (s *SubstringIndex) String() string {
   288  	return fmt.Sprintf("SUBSTRING_INDEX(%s, %s, %s)", s.str, s.delim, s.count)
   289  }
   290  
   291  // Resolved implements the Expression interface.
   292  func (s *SubstringIndex) Resolved() bool {
   293  	return s.str.Resolved() && s.delim.Resolved() && s.count.Resolved()
   294  }
   295  
   296  // Type implements the Expression interface.
   297  func (*SubstringIndex) Type() sql.Type { return types.LongText }
   298  
   299  // CollationCoercibility implements the interface sql.CollationCoercible.
   300  func (s *SubstringIndex) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   301  	return sql.GetCoercibility(ctx, s.str)
   302  }
   303  
   304  // WithChildren implements the Expression interface.
   305  func (s *SubstringIndex) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   306  	if len(children) != 3 {
   307  		return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 3)
   308  	}
   309  	return NewSubstringIndex(children[0], children[1], children[2]), nil
   310  }
   311  
   312  // Left is a function that returns the first N characters of a string expression.
   313  type Left struct {
   314  	str sql.Expression
   315  	len sql.Expression
   316  }
   317  
   318  var _ sql.FunctionExpression = Left{}
   319  var _ sql.CollationCoercible = Left{}
   320  
   321  // NewLeft creates a new LEFT function.
   322  func NewLeft(str, len sql.Expression) sql.Expression {
   323  	return Left{str, len}
   324  }
   325  
   326  // FunctionName implements sql.FunctionExpression
   327  func (l Left) FunctionName() string {
   328  	return "left"
   329  }
   330  
   331  // Description implements sql.FunctionExpression
   332  func (l Left) Description() string {
   333  	return "returns the first N characters in the string given."
   334  }
   335  
   336  // Children implements the Expression interface.
   337  func (l Left) Children() []sql.Expression {
   338  	return []sql.Expression{l.str, l.len}
   339  }
   340  
   341  // Eval implements the Expression interface.
   342  func (l Left) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   343  	str, err := l.str.Eval(ctx, row)
   344  	if err != nil {
   345  		return nil, err
   346  	}
   347  
   348  	var text []rune
   349  	switch str := str.(type) {
   350  	case string:
   351  		text = []rune(str)
   352  	case []byte:
   353  		text = []rune(string(str))
   354  	case nil:
   355  		return nil, nil
   356  	default:
   357  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(str).String())
   358  	}
   359  
   360  	var length int64
   361  	runeCount := int64(len(text))
   362  	len, err := l.len.Eval(ctx, row)
   363  	if err != nil {
   364  		return nil, err
   365  	}
   366  
   367  	if len == nil {
   368  		return nil, nil
   369  	}
   370  
   371  	len, _, err = types.Int64.Convert(len)
   372  	if err != nil {
   373  		return nil, err
   374  	}
   375  
   376  	length = len.(int64)
   377  
   378  	if length > runeCount {
   379  		length = runeCount
   380  	}
   381  	if length <= 0 {
   382  		return "", nil
   383  	}
   384  
   385  	return string(text[:length]), nil
   386  }
   387  
   388  // IsNullable implements the Expression interface.
   389  func (l Left) IsNullable() bool {
   390  	return l.str.IsNullable() || l.len.IsNullable()
   391  }
   392  
   393  func (l Left) String() string {
   394  	return fmt.Sprintf("LEFT(%s, %s)", l.str, l.len)
   395  }
   396  
   397  // Resolved implements the Expression interface.
   398  func (l Left) Resolved() bool {
   399  	return l.str.Resolved() && l.len.Resolved()
   400  }
   401  
   402  // Type implements the Expression interface.
   403  func (Left) Type() sql.Type { return types.LongText }
   404  
   405  // CollationCoercibility implements the interface sql.CollationCoercible.
   406  func (l Left) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   407  	return sql.GetCoercibility(ctx, l.str)
   408  }
   409  
   410  // WithChildren implements the Expression interface.
   411  func (l Left) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   412  	if len(children) != 2 {
   413  		return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 2)
   414  	}
   415  	return NewLeft(children[0], children[1]), nil
   416  }
   417  
   418  // Right is a function that returns the last N characters of a string expression.
   419  type Right struct {
   420  	str sql.Expression
   421  	len sql.Expression
   422  }
   423  
   424  var _ sql.FunctionExpression = Right{}
   425  var _ sql.CollationCoercible = Right{}
   426  
   427  // NewRight creates a new RIGHT function.
   428  func NewRight(str, len sql.Expression) sql.Expression {
   429  	return Right{str, len}
   430  }
   431  
   432  // FunctionName implements sql.FunctionExpression
   433  func (r Right) FunctionName() string {
   434  	return "right"
   435  }
   436  
   437  // Description implements sql.FunctionExpression
   438  func (r Right) Description() string {
   439  	return "returns the specified rightmost number of characters."
   440  }
   441  
   442  // Children implements the Expression interface.
   443  func (r Right) Children() []sql.Expression {
   444  	return []sql.Expression{r.str, r.len}
   445  }
   446  
   447  // Eval implements the Expression interface.
   448  func (r Right) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   449  	str, err := r.str.Eval(ctx, row)
   450  	if err != nil {
   451  		return nil, err
   452  	}
   453  
   454  	var text []rune
   455  	switch str := str.(type) {
   456  	case string:
   457  		text = []rune(str)
   458  	case []byte:
   459  		text = []rune(string(str))
   460  	case nil:
   461  		return nil, nil
   462  	default:
   463  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(str).String())
   464  	}
   465  
   466  	var length int64
   467  	runeCount := int64(len(text))
   468  	len, err := r.len.Eval(ctx, row)
   469  	if err != nil {
   470  		return nil, err
   471  	}
   472  
   473  	if len == nil {
   474  		return nil, nil
   475  	}
   476  
   477  	len, _, err = types.Int64.Convert(len)
   478  	if err != nil {
   479  		return nil, err
   480  	}
   481  
   482  	length = len.(int64)
   483  
   484  	if length > runeCount {
   485  		length = runeCount
   486  	}
   487  	if length <= 0 {
   488  		return "", nil
   489  	}
   490  
   491  	return string(text[runeCount-length:]), nil
   492  }
   493  
   494  // IsNullable implements the Expression interface.
   495  func (r Right) IsNullable() bool {
   496  	return r.str.IsNullable() || r.len.IsNullable()
   497  }
   498  
   499  func (r Right) String() string {
   500  	return fmt.Sprintf("RIGHT(%s, %s)", r.str, r.len)
   501  }
   502  
   503  func (r Right) DebugString() string {
   504  	pr := sql.NewTreePrinter()
   505  	_ = pr.WriteNode("RIGHT")
   506  	children := []string{
   507  		fmt.Sprintf("str: %s", sql.DebugString(r.str)),
   508  		fmt.Sprintf("len: %s", sql.DebugString(r.len)),
   509  	}
   510  	_ = pr.WriteChildren(children...)
   511  	return pr.String()
   512  }
   513  
   514  // Resolved implements the Expression interface.
   515  func (r Right) Resolved() bool {
   516  	return r.str.Resolved() && r.len.Resolved()
   517  }
   518  
   519  // Type implements the Expression interface.
   520  func (Right) Type() sql.Type { return types.LongText }
   521  
   522  // CollationCoercibility implements the interface sql.CollationCoercible.
   523  func (r Right) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   524  	return sql.GetCoercibility(ctx, r.str)
   525  }
   526  
   527  // WithChildren implements the Expression interface.
   528  func (r Right) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   529  	if len(children) != 2 {
   530  		return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 2)
   531  	}
   532  	return NewRight(children[0], children[1]), nil
   533  }
   534  
   535  type Instr struct {
   536  	str    sql.Expression
   537  	substr sql.Expression
   538  }
   539  
   540  var _ sql.FunctionExpression = Instr{}
   541  var _ sql.CollationCoercible = Instr{}
   542  
   543  // NewInstr creates a new instr UDF.
   544  func NewInstr(str, substr sql.Expression) sql.Expression {
   545  	return Instr{str, substr}
   546  }
   547  
   548  // FunctionName implements sql.FunctionExpression
   549  func (i Instr) FunctionName() string {
   550  	return "instr"
   551  }
   552  
   553  // Description implements sql.FunctionExpression
   554  func (i Instr) Description() string {
   555  	return "returns the 1-based index of the first occurence of str2 in str1, or 0 if it does not occur."
   556  }
   557  
   558  // Children implements the Expression interface.
   559  func (i Instr) Children() []sql.Expression {
   560  	return []sql.Expression{i.str, i.substr}
   561  }
   562  
   563  // Eval implements the Expression interface.
   564  func (i Instr) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   565  	str, err := i.str.Eval(ctx, row)
   566  	if err != nil {
   567  		return nil, err
   568  	}
   569  
   570  	var text []rune
   571  	switch str := str.(type) {
   572  	case string:
   573  		text = []rune(str)
   574  	case []byte:
   575  		text = []rune(string(str))
   576  	case nil:
   577  		return nil, nil
   578  	default:
   579  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(str).String())
   580  	}
   581  
   582  	substr, err := i.substr.Eval(ctx, row)
   583  	if err != nil {
   584  		return nil, err
   585  	}
   586  
   587  	var subtext []rune
   588  	switch substr := substr.(type) {
   589  	case string:
   590  		subtext = []rune(substr)
   591  	case []byte:
   592  		subtext = []rune(string(subtext))
   593  	case nil:
   594  		return nil, nil
   595  	default:
   596  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(str).String())
   597  	}
   598  
   599  	return findSubsequence(text, subtext) + 1, nil
   600  }
   601  
   602  func findSubsequence(text []rune, subtext []rune) int64 {
   603  	for i := 0; i <= len(text)-len(subtext); i++ {
   604  		var j int
   605  		for j = 0; j < len(subtext); j++ {
   606  			if text[i+j] != subtext[j] {
   607  				break
   608  			}
   609  		}
   610  		if j == len(subtext) {
   611  			return int64(i)
   612  		}
   613  	}
   614  	return -1
   615  }
   616  
   617  // IsNullable implements the Expression interface.
   618  func (i Instr) IsNullable() bool {
   619  	return i.str.IsNullable() || i.substr.IsNullable()
   620  }
   621  
   622  func (i Instr) String() string {
   623  	return fmt.Sprintf("INSTR(%s, %s)", i.str, i.substr)
   624  }
   625  
   626  // Resolved implements the Expression interface.
   627  func (i Instr) Resolved() bool {
   628  	return i.str.Resolved() && i.substr.Resolved()
   629  }
   630  
   631  // Type implements the Expression interface.
   632  func (Instr) Type() sql.Type { return types.Int64 }
   633  
   634  // CollationCoercibility implements the interface sql.CollationCoercible.
   635  func (Instr) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   636  	return sql.Collation_binary, 5
   637  }
   638  
   639  // WithChildren implements the Expression interface.
   640  func (i Instr) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   641  	if len(children) != 2 {
   642  		return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 2)
   643  	}
   644  	return NewInstr(children[0], children[1]), nil
   645  }