github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/reverse_repeat_replace.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"gopkg.in/src-d/go-errors.v1"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/types"
    26  )
    27  
    28  // Reverse is a function that returns the reverse of the text provided.
    29  type Reverse struct {
    30  	expression.UnaryExpression
    31  }
    32  
    33  var _ sql.FunctionExpression = (*Reverse)(nil)
    34  var _ sql.CollationCoercible = (*Reverse)(nil)
    35  
    36  // NewReverse creates a new Reverse expression.
    37  func NewReverse(e sql.Expression) sql.Expression {
    38  	return &Reverse{expression.UnaryExpression{Child: e}}
    39  }
    40  
    41  // FunctionName implements sql.FunctionExpression
    42  func (r *Reverse) FunctionName() string {
    43  	return "reverse"
    44  }
    45  
    46  // Description implements sql.FunctionExpression
    47  func (r *Reverse) Description() string {
    48  	return "returns the string str with the order of the characters reversed."
    49  }
    50  
    51  // Eval implements the Expression interface.
    52  func (r *Reverse) Eval(
    53  	ctx *sql.Context,
    54  	row sql.Row,
    55  ) (interface{}, error) {
    56  	//TODO: handle collations
    57  	v, err := r.Child.Eval(ctx, row)
    58  	if v == nil || err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	v, _, err = types.LongText.Convert(v)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	return reverseString(v.(string)), nil
    68  }
    69  
    70  func reverseString(s string) string {
    71  	r := []rune(s)
    72  	for i, j := 0, len(r)-1; i < j; i, j = i+1, j-1 {
    73  		r[i], r[j] = r[j], r[i]
    74  	}
    75  	return string(r)
    76  }
    77  
    78  func (r *Reverse) String() string {
    79  	return fmt.Sprintf("reverse(%s)", r.Child)
    80  }
    81  
    82  // WithChildren implements the Expression interface.
    83  func (r *Reverse) WithChildren(children ...sql.Expression) (sql.Expression, error) {
    84  	if len(children) != 1 {
    85  		return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 1)
    86  	}
    87  	return NewReverse(children[0]), nil
    88  }
    89  
// Type implements the Expression interface. It reports the type of the child
// expression unchanged.
//
// NOTE(review): Eval converts its input to LongText and always returns a Go
// string, so for a non-text child the type reported here may not describe the
// evaluated value — confirm this is intended.
func (r *Reverse) Type() sql.Type {
	return r.Child.Type()
}
    94  
    95  // CollationCoercibility implements the interface sql.CollationCoercible.
    96  func (r *Reverse) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    97  	return sql.GetCoercibility(ctx, r.Child)
    98  }
    99  
// ErrNegativeRepeatCount is the error kind returned by Repeat.Eval when the
// count argument evaluates to a negative number; the offending count is
// interpolated into the message.
var ErrNegativeRepeatCount = errors.NewKind("negative Repeat count: %v")
   101  
   102  // Repeat is a function that returns the string repeated n times.
   103  type Repeat struct {
   104  	expression.BinaryExpressionStub
   105  }
   106  
   107  var _ sql.FunctionExpression = (*Repeat)(nil)
   108  var _ sql.CollationCoercible = (*Repeat)(nil)
   109  
   110  // NewRepeat creates a new Repeat expression.
   111  func NewRepeat(str sql.Expression, count sql.Expression) sql.Expression {
   112  	return &Repeat{expression.BinaryExpressionStub{LeftChild: str, RightChild: count}}
   113  }
   114  
   115  // FunctionName implements sql.FunctionExpression
   116  func (r *Repeat) FunctionName() string {
   117  	return "repeat"
   118  }
   119  
   120  // Description implements sql.FunctionExpression
   121  func (r *Repeat) Description() string {
   122  	return "returns a string consisting of the string str repeated count times."
   123  }
   124  
   125  func (r *Repeat) String() string {
   126  	return fmt.Sprintf("repeat(%s, %s)", r.LeftChild, r.RightChild)
   127  }
   128  
// Type implements the Expression interface. The result type is always
// types.LongText, independent of the argument types.
func (r *Repeat) Type() sql.Type {
	return types.LongText
}
   133  
   134  // CollationCoercibility implements the interface sql.CollationCoercible.
   135  func (r *Repeat) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   136  	leftCollation, leftCoercibility := sql.GetCoercibility(ctx, r.LeftChild)
   137  	rightCollation, rightCoercibility := sql.GetCoercibility(ctx, r.RightChild)
   138  	return sql.ResolveCoercibility(leftCollation, leftCoercibility, rightCollation, rightCoercibility)
   139  }
   140  
   141  // WithChildren implements the Expression interface.
   142  func (r *Repeat) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   143  	if len(children) != 2 {
   144  		return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 2)
   145  	}
   146  	return NewRepeat(children[0], children[1]), nil
   147  }
   148  
   149  // Eval implements the Expression interface.
   150  func (r *Repeat) Eval(
   151  	ctx *sql.Context,
   152  	row sql.Row,
   153  ) (interface{}, error) {
   154  	//TODO: handle collations
   155  	str, err := r.LeftChild.Eval(ctx, row)
   156  	if str == nil || err != nil {
   157  		return nil, err
   158  	}
   159  
   160  	str, _, err = types.LongText.Convert(str)
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  
   165  	count, err := r.RightChild.Eval(ctx, row)
   166  	if count == nil || err != nil {
   167  		return nil, err
   168  	}
   169  
   170  	count, _, err = types.Int32.Convert(count)
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  	if count.(int32) < 0 {
   175  		return nil, ErrNegativeRepeatCount.New(count)
   176  	}
   177  	return strings.Repeat(str.(string), int(count.(int32))), nil
   178  }
   179  
   180  // Replace is a function that returns a string with all occurrences of fromStr replaced by the
   181  // string toStr
   182  type Replace struct {
   183  	str     sql.Expression
   184  	fromStr sql.Expression
   185  	toStr   sql.Expression
   186  }
   187  
   188  var _ sql.FunctionExpression = (*Replace)(nil)
   189  var _ sql.CollationCoercible = (*Replace)(nil)
   190  
   191  // NewReplace creates a new Replace expression.
   192  func NewReplace(str sql.Expression, fromStr sql.Expression, toStr sql.Expression) sql.Expression {
   193  	return &Replace{str, fromStr, toStr}
   194  }
   195  
   196  // FunctionName implements sql.FunctionExpression
   197  func (r *Replace) FunctionName() string {
   198  	return "replace"
   199  }
   200  
   201  // Description implements sql.FunctionExpression
   202  func (r *Replace) Description() string {
   203  	return "returns the string str with all occurrences of the string from_str replaced by the string to_str."
   204  }
   205  
   206  // Children implements the Expression interface.
   207  func (r *Replace) Children() []sql.Expression {
   208  	return []sql.Expression{r.str, r.fromStr, r.toStr}
   209  }
   210  
   211  // Resolved implements the Expression interface.
   212  func (r *Replace) Resolved() bool {
   213  	return r.str.Resolved() && r.fromStr.Resolved() && r.toStr.Resolved()
   214  }
   215  
   216  // IsNullable implements the Expression interface.
   217  func (r *Replace) IsNullable() bool {
   218  	return r.str.IsNullable() || r.fromStr.IsNullable() || r.toStr.IsNullable()
   219  }
   220  
   221  func (r *Replace) String() string {
   222  	return fmt.Sprintf("replace(%s, %s, %s)", r.str, r.fromStr, r.toStr)
   223  }
   224  
// Type implements the Expression interface. The result type is always
// types.LongText, independent of the argument types.
func (r *Replace) Type() sql.Type {
	return types.LongText
}
   229  
   230  // CollationCoercibility implements the interface sql.CollationCoercible.
   231  func (r *Replace) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   232  	collation, coercibility = sql.GetCoercibility(ctx, r.str)
   233  	otherCollation, otherCoercibility := sql.GetCoercibility(ctx, r.fromStr)
   234  	collation, coercibility = sql.ResolveCoercibility(collation, coercibility, otherCollation, otherCoercibility)
   235  	otherCollation, otherCoercibility = sql.GetCoercibility(ctx, r.toStr)
   236  	return sql.ResolveCoercibility(collation, coercibility, otherCollation, otherCoercibility)
   237  }
   238  
   239  // WithChildren implements the Expression interface.
   240  func (r *Replace) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   241  	if len(children) != 3 {
   242  		return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 3)
   243  	}
   244  	return NewReplace(children[0], children[1], children[2]), nil
   245  }
   246  
   247  // Eval implements the Expression interface.
   248  func (r *Replace) Eval(
   249  	ctx *sql.Context,
   250  	row sql.Row,
   251  ) (interface{}, error) {
   252  	//TODO: handle collations
   253  	str, err := r.str.Eval(ctx, row)
   254  	if str == nil || err != nil {
   255  		return nil, err
   256  	}
   257  
   258  	str, _, err = types.LongText.Convert(str)
   259  	if err != nil {
   260  		return nil, err
   261  	}
   262  
   263  	fromStr, err := r.fromStr.Eval(ctx, row)
   264  	if fromStr == nil || err != nil {
   265  		return nil, err
   266  	}
   267  
   268  	fromStr, _, err = types.LongText.Convert(fromStr)
   269  	if err != nil {
   270  		return nil, err
   271  	}
   272  
   273  	toStr, err := r.toStr.Eval(ctx, row)
   274  	if toStr == nil || err != nil {
   275  		return nil, err
   276  	}
   277  
   278  	toStr, _, err = types.LongText.Convert(toStr)
   279  	if err != nil {
   280  		return nil, err
   281  	}
   282  
   283  	if fromStr.(string) == "" {
   284  		return str, nil
   285  	}
   286  
   287  	return strings.Replace(str.(string), fromStr.(string), toStr.(string), -1), nil
   288  }