github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/rpad_lpad.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
import (
	"fmt"
	"reflect"
	"strings"
	"unicode/utf8"

	"gopkg.in/src-d/go-errors.v1"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/types"
)
    27  
    28  var ErrDivisionByZero = errors.NewKind("division by zero")
    29  
    30  type padType rune
    31  
    32  const (
    33  	lPadType padType = 'l'
    34  	rPadType padType = 'r'
    35  )
    36  
    37  func NewLeftPad(e ...sql.Expression) (sql.Expression, error) {
    38  	return NewPad(lPadType, e...)
    39  }
    40  
    41  func NewRightPad(e ...sql.Expression) (sql.Expression, error) {
    42  	return NewPad(rPadType, e...)
    43  }
    44  
    45  // NewPad creates a new Pad expression.
    46  func NewPad(pType padType, args ...sql.Expression) (sql.Expression, error) {
    47  	argLen := len(args)
    48  	if argLen != 3 {
    49  		return nil, sql.ErrInvalidArgumentNumber.New(string(pType)+"pad", "3", argLen)
    50  	}
    51  
    52  	return &Pad{args[0], args[1], args[2], pType}, nil
    53  }
    54  
// Pad is a function that pads a string with another string.
// One type backs both LPAD and RPAD; padType selects the side that is filled.
type Pad struct {
	str     sql.Expression // string to pad (first argument)
	length  sql.Expression // target length (second argument)
	padStr  sql.Expression // filler string (third argument)
	padType padType        // lPadType or rPadType
}

// Compile-time checks that *Pad implements the interfaces it is used through.
var _ sql.FunctionExpression = (*Pad)(nil)
var _ sql.CollationCoercible = (*Pad)(nil)
    65  
    66  // FunctionName implements sql.FunctionExpression
    67  func (p *Pad) FunctionName() string {
    68  	if p.padType == lPadType {
    69  		return "lpad"
    70  	} else if p.padType == rPadType {
    71  		return "rpad"
    72  	} else {
    73  		panic("unknown name for pad type")
    74  	}
    75  }
    76  
    77  // Description implements sql.FunctionExpression
    78  func (p *Pad) Description() string {
    79  	if p.padType == lPadType {
    80  		return "returns the string str, left-padded with the string padstr to a length of len characters."
    81  	} else if p.padType == rPadType {
    82  		return "returns the string str, right-padded with the string padstr to a length of len characters."
    83  	} else {
    84  		panic("unknown description for pad type")
    85  	}
    86  }
    87  
    88  // Children implements the Expression interface.
    89  func (p *Pad) Children() []sql.Expression {
    90  	return []sql.Expression{p.str, p.length, p.padStr}
    91  }
    92  
    93  // Resolved implements the Expression interface.
    94  func (p *Pad) Resolved() bool {
    95  	return p.str.Resolved() && p.length.Resolved() && (p.padStr.Resolved())
    96  }
    97  
    98  // IsNullable implements the Expression interface.
    99  func (p *Pad) IsNullable() bool {
   100  	return p.str.IsNullable() || p.length.IsNullable() || p.padStr.IsNullable()
   101  }
   102  
   103  // Type implements the Expression interface.
   104  func (p *Pad) Type() sql.Type { return types.LongText }
   105  
   106  // CollationCoercibility implements the interface sql.CollationCoercible.
   107  func (p *Pad) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   108  	leftCollation, leftCoercibility := sql.GetCoercibility(ctx, p.str)
   109  	rightCollation, rightCoercibility := sql.GetCoercibility(ctx, p.padStr)
   110  	return sql.ResolveCoercibility(leftCollation, leftCoercibility, rightCollation, rightCoercibility)
   111  }
   112  
   113  func (p *Pad) String() string {
   114  	if p.padType == lPadType {
   115  		return fmt.Sprintf("lpad(%s, %s, %s)", p.str, p.length, p.padStr)
   116  	}
   117  	return fmt.Sprintf("rpad(%s, %s, %s)", p.str, p.length, p.padStr)
   118  }
   119  
   120  // WithChildren implements the Expression interface.
   121  func (p *Pad) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   122  	return NewPad(p.padType, children...)
   123  }
   124  
   125  // Eval implements the Expression interface.
   126  func (p *Pad) Eval(
   127  	ctx *sql.Context,
   128  	row sql.Row,
   129  ) (interface{}, error) {
   130  	str, err := p.str.Eval(ctx, row)
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  
   135  	if str == nil {
   136  		return nil, nil
   137  	}
   138  
   139  	str, _, err = types.LongText.Convert(str)
   140  	if err != nil {
   141  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(str))
   142  	}
   143  
   144  	length, err := p.length.Eval(ctx, row)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  
   149  	if length == nil {
   150  		return nil, nil
   151  	}
   152  
   153  	length, _, err = types.Int64.Convert(length)
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  
   158  	padStr, err := p.padStr.Eval(ctx, row)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  
   163  	if padStr == nil {
   164  		return nil, nil
   165  	}
   166  
   167  	padStr, _, err = types.LongText.Convert(padStr)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  
   172  	return padString(str.(string), length.(int64), padStr.(string), p.padType)
   173  }
   174  
   175  func padString(str string, length int64, padStr string, padType padType) (string, error) {
   176  	if length <= 0 {
   177  		return "", nil
   178  	}
   179  	if int64(len(str)) >= length {
   180  		return str[:length], nil
   181  	}
   182  	if len(padStr) == 0 {
   183  		return "", nil
   184  	}
   185  
   186  	padLen := int(length - int64(len(str)))
   187  	quo, rem, err := divmod(int64(padLen), int64(len(padStr)))
   188  	if err != nil {
   189  		return "", err
   190  	}
   191  
   192  	if padType == lPadType {
   193  		result := strings.Repeat(padStr, int(quo)) + padStr[:rem] + str
   194  		return result[:length], nil
   195  	}
   196  	result := str + strings.Repeat(padStr, int(quo)) + padStr[:rem]
   197  	return result[(int64(len(result)) - length):], nil
   198  }
   199  
   200  func divmod(a, b int64) (quotient, remainder int64, err error) {
   201  	if b == 0 {
   202  		return 0, 0, ErrDivisionByZero.New()
   203  	}
   204  	quotient = a / b
   205  	remainder = a % b
   206  	return
   207  }