github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/length.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"fmt"
    19  	"unicode/utf8"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql/encodings"
    22  	"github.com/dolthub/go-mysql-server/sql/types"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/expression"
    26  )
    27  
// Length is a unary SQL expression that measures the value produced by its
// child expression. The CountType field selects whether the measurement is
// reported in bytes or in characters.
type Length struct {
	expression.UnaryExpression
	// CountType selects the unit of measurement: NumBytes or NumChars.
	CountType CountType
}
    34  
    35  var _ sql.FunctionExpression = (*Length)(nil)
    36  var _ sql.CollationCoercible = (*Length)(nil)
    37  
    38  // CountType is the kind of length count.
    39  type CountType bool
    40  
    41  const (
    42  	// NumBytes counts the number of bytes in a string or binary content.
    43  	NumBytes = CountType(false)
    44  	// NumChars counts the number of characters in a string or binary content.
    45  	NumChars = CountType(true)
    46  )
    47  
    48  // NewLength returns a new LENGTH function.
    49  func NewLength(e sql.Expression) sql.Expression {
    50  	return &Length{expression.UnaryExpression{Child: e}, NumBytes}
    51  }
    52  
    53  // NewCharLength returns a new CHAR_LENGTH function.
    54  func NewCharLength(e sql.Expression) sql.Expression {
    55  	return &Length{expression.UnaryExpression{Child: e}, NumChars}
    56  }
    57  
    58  // FunctionName implements sql.FunctionExpression
    59  func (l *Length) FunctionName() string {
    60  	if l.CountType == NumChars {
    61  		return "character_length"
    62  	} else if l.CountType == NumBytes {
    63  		return "length"
    64  	} else {
    65  		panic("unknown name for length count type")
    66  	}
    67  }
    68  
    69  // Description implements sql.FunctionExpression
    70  func (l *Length) Description() string {
    71  	if l.CountType == NumChars {
    72  		return "returns the length of the string in characters."
    73  	} else if l.CountType == NumBytes {
    74  		return "returns the length of the string in bytes."
    75  	} else {
    76  		panic("unknown description for length count type")
    77  	}
    78  }
    79  
    80  // WithChildren implements the Expression interface.
    81  func (l *Length) WithChildren(children ...sql.Expression) (sql.Expression, error) {
    82  	if len(children) != 1 {
    83  		return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1)
    84  	}
    85  
    86  	return &Length{expression.UnaryExpression{Child: children[0]}, l.CountType}, nil
    87  }
    88  
    89  // Type implements the sql.Expression interface.
    90  func (l *Length) Type() sql.Type { return types.Int32 }
    91  
    92  // CollationCoercibility implements the interface sql.CollationCoercible.
    93  func (*Length) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    94  	return sql.Collation_binary, 5
    95  }
    96  
    97  func (l *Length) String() string {
    98  	if l.CountType == NumBytes {
    99  		return fmt.Sprintf("length(%s)", l.Child)
   100  	}
   101  	return fmt.Sprintf("char_length(%s)", l.Child)
   102  }
   103  
   104  func (l *Length) DebugString() string {
   105  	if l.CountType == NumBytes {
   106  		return fmt.Sprintf("length(%s)", sql.DebugString(l.Child))
   107  	}
   108  	return fmt.Sprintf("char_length(%s)", sql.DebugString(l.Child))
   109  }
   110  
   111  // Eval implements the sql.Expression interface.
   112  func (l *Length) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   113  	val, err := l.Child.Eval(ctx, row)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	if val == nil {
   119  		return nil, nil
   120  	}
   121  
   122  	content, collation, err := types.ConvertToCollatedString(val, l.Child.Type())
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  	charSetEncoder := collation.CharacterSet().Encoder()
   127  	if l.CountType == NumBytes {
   128  		encodedContent, ok := charSetEncoder.Encode(encodings.StringToBytes(content))
   129  		if !ok {
   130  			return nil, fmt.Errorf("unable to re-encode string for LENGTH function")
   131  		}
   132  		return int32(len(encodedContent)), nil
   133  	} else {
   134  		contentLen := int32(0)
   135  		for len(content) > 0 {
   136  			cr, cRead := charSetEncoder.NextRune(content)
   137  			if cRead == 0 || cr == utf8.RuneError {
   138  				return 0, sql.ErrCollationMalformedString.New("checking length")
   139  			}
   140  			content = content[cRead:]
   141  			contentLen++
   142  		}
   143  		return contentLen, nil
   144  	}
   145  }