github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/trim_ltrim_rtrim.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"fmt"
    19  	"reflect"
    20  	"strings"
    21  
    22  	"github.com/dolthub/vitess/go/vt/sqlparser"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/expression"
    26  	"github.com/dolthub/go-mysql-server/sql/types"
    27  )
    28  
    29  type Trim struct {
    30  	str sql.Expression
    31  	pat sql.Expression
    32  	dir string
    33  }
    34  
    35  var _ sql.FunctionExpression = (*Trim)(nil)
    36  var _ sql.CollationCoercible = (*Trim)(nil)
    37  
    38  func NewTrim(str sql.Expression, pat sql.Expression, dir string) sql.Expression {
    39  	return &Trim{str, pat, dir}
    40  }
    41  
    42  // FunctionName implements sql.FunctionExpression
    43  func (t *Trim) FunctionName() string {
    44  	return "trim"
    45  }
    46  
    47  // Description implements sql.FunctionExpression
    48  func (t *Trim) Description() string {
    49  	return "remove leading and trailing spaces."
    50  }
    51  
    52  // Children implements the Expression interface.
    53  func (t *Trim) Children() []sql.Expression {
    54  	return []sql.Expression{t.str, t.pat}
    55  }
    56  
    57  // Eval implements the Expression interface.
    58  func (t *Trim) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    59  	// Evaluate pattern
    60  	pat, err := t.pat.Eval(ctx, row)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  
    65  	// Convert pat into string
    66  	pat, _, err = types.LongText.Convert(pat)
    67  	if err != nil {
    68  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(pat).String())
    69  	}
    70  
    71  	// Evaluate string value
    72  	str, err := t.str.Eval(ctx, row)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	// Nil string
    78  	if str == nil {
    79  		return nil, nil
    80  	}
    81  
    82  	// Convert pat into string
    83  	str, _, err = types.LongText.Convert(str)
    84  	if err != nil {
    85  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(str).String())
    86  	}
    87  
    88  	start := 0
    89  	end := len(str.(string))
    90  	n := len(pat.(string))
    91  
    92  	// Empty pattern, do nothing
    93  	if n == 0 {
    94  		return str, nil
    95  	}
    96  
    97  	// Trim Leading
    98  	if t.dir == sqlparser.Leading || t.dir == sqlparser.Both {
    99  		for start+n <= end && str.(string)[start:start+n] == pat {
   100  			start += n
   101  		}
   102  	}
   103  
   104  	// Trim Trailing
   105  	if t.dir == sqlparser.Trailing || t.dir == sqlparser.Both {
   106  		for start+n <= end && str.(string)[end-n:end] == pat {
   107  			end -= n
   108  		}
   109  	}
   110  
   111  	return str.(string)[start:end], nil
   112  }
   113  
   114  // IsNullable implements the Expression interface.
   115  func (t Trim) IsNullable() bool {
   116  	return t.str.IsNullable() || t.pat.IsNullable()
   117  }
   118  
   119  func (t Trim) String() string {
   120  	if t.dir == sqlparser.Leading {
   121  		return fmt.Sprintf("trim(leading %v from %v)", t.pat, t.str)
   122  	} else if t.dir == sqlparser.Trailing {
   123  		return fmt.Sprintf("trim(trailing %v from %v)", t.pat, t.str)
   124  	} else {
   125  		if t.pat.String() == " " {
   126  			return fmt.Sprintf("trim(%v)", t.str)
   127  		}
   128  		return fmt.Sprintf("trim(both %v from %v)", t.pat, t.str)
   129  	}
   130  }
   131  
   132  func (t Trim) Resolved() bool {
   133  	return t.str.Resolved() && t.pat.Resolved() && t.pat.Resolved()
   134  }
   135  
   136  func (t Trim) Type() sql.Type { return t.str.Type() }
   137  
   138  // CollationCoercibility implements the interface sql.CollationCoercible.
   139  func (t Trim) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   140  	leftCollation, leftCoercibility := sql.GetCoercibility(ctx, t.str)
   141  	rightCollation, rightCoercibility := sql.GetCoercibility(ctx, t.pat)
   142  	return sql.ResolveCoercibility(leftCollation, leftCoercibility, rightCollation, rightCoercibility)
   143  }
   144  
   145  func (t Trim) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   146  	if len(children) != 2 {
   147  		return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 2)
   148  	}
   149  	return NewTrim(children[0], children[1], t.dir), nil
   150  }
   151  
   152  type LeftTrim struct {
   153  	expression.UnaryExpression
   154  }
   155  
   156  func NewLeftTrim(str sql.Expression) sql.Expression {
   157  	return &LeftTrim{expression.UnaryExpression{Child: str}}
   158  }
   159  
   160  var _ sql.FunctionExpression = (*LeftTrim)(nil)
   161  var _ sql.CollationCoercible = (*LeftTrim)(nil)
   162  
   163  // FunctionName implements sql.FunctionExpression
   164  func (t *LeftTrim) FunctionName() string {
   165  	return "ltrim"
   166  }
   167  
   168  // Description implements sql.FunctionExpression
   169  func (t *LeftTrim) Description() string {
   170  	return "returns the string str with leading space characters removed."
   171  }
   172  
   173  func (t *LeftTrim) Type() sql.Type { return t.Child.Type() }
   174  
   175  // CollationCoercibility implements the interface sql.CollationCoercible.
   176  func (t *LeftTrim) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   177  	return sql.GetCoercibility(ctx, t.Child)
   178  }
   179  
   180  func (t *LeftTrim) String() string {
   181  	return fmt.Sprintf("ltrim(%s)", t.Child)
   182  }
   183  
   184  func (t *LeftTrim) IsNullable() bool {
   185  	return t.Child.IsNullable()
   186  }
   187  
   188  func (t *LeftTrim) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   189  	if len(children) != 1 {
   190  		return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1)
   191  	}
   192  	return NewLeftTrim(children[0]), nil
   193  }
   194  
   195  func (t *LeftTrim) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   196  	str, err := t.Child.Eval(ctx, row)
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  
   201  	if str == nil {
   202  		return nil, nil
   203  	}
   204  
   205  	str, _, err = types.LongText.Convert(str)
   206  	if err != nil {
   207  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(str))
   208  	}
   209  
   210  	return strings.TrimLeftFunc(str.(string), func(r rune) bool {
   211  		return r == ' '
   212  	}), nil
   213  }
   214  
   215  type RightTrim struct {
   216  	expression.UnaryExpression
   217  }
   218  
   219  func NewRightTrim(str sql.Expression) sql.Expression {
   220  	return &RightTrim{expression.UnaryExpression{Child: str}}
   221  }
   222  
   223  var _ sql.FunctionExpression = (*RightTrim)(nil)
   224  var _ sql.CollationCoercible = (*RightTrim)(nil)
   225  
   226  // FunctionName implements sql.FunctionExpression
   227  func (t *RightTrim) FunctionName() string {
   228  	return "rtrim"
   229  }
   230  
   231  // Description implements sql.FunctionExpression
   232  func (t *RightTrim) Description() string {
   233  	return "returns the string str with trailing space characters removed."
   234  }
   235  
   236  func (t *RightTrim) Type() sql.Type { return t.Child.Type() }
   237  
   238  // CollationCoercibility implements the interface sql.CollationCoercible.
   239  func (t *RightTrim) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   240  	return sql.GetCoercibility(ctx, t.Child)
   241  }
   242  
   243  func (t *RightTrim) String() string {
   244  	return fmt.Sprintf("rtrim(%s)", t.Child)
   245  }
   246  
   247  func (t *RightTrim) IsNullable() bool {
   248  	return t.Child.IsNullable()
   249  }
   250  
   251  func (t *RightTrim) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   252  	if len(children) != 1 {
   253  		return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1)
   254  	}
   255  	return NewRightTrim(children[0]), nil
   256  }
   257  
   258  func (t *RightTrim) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   259  	str, err := t.Child.Eval(ctx, row)
   260  	if err != nil {
   261  		return nil, err
   262  	}
   263  
   264  	if str == nil {
   265  		return nil, nil
   266  	}
   267  
   268  	str, _, err = types.LongText.Convert(str)
   269  	if err != nil {
   270  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(str))
   271  	}
   272  
   273  	return strings.TrimRightFunc(str.(string), func(r rune) bool {
   274  		return r == ' '
   275  	}), nil
   276  }