vitess.io/vitess@v0.16.2/go/vt/vtgate/evalengine/string.go (about)

     1  /*
     2  Copyright 2022 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package evalengine
    18  
    19  import (
    20  	"bytes"
    21  
    22  	"vitess.io/vitess/go/mysql/collations"
    23  	"vitess.io/vitess/go/sqltypes"
    24  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    25  	"vitess.io/vitess/go/vt/vterrors"
    26  )
    27  
    28  type builtinLower struct{}
    29  
    30  func (builtinLower) call(env *ExpressionEnv, args []EvalResult, result *EvalResult) {
    31  	inarg := &args[0]
    32  
    33  	switch {
    34  	case inarg.isNull():
    35  		result.setNull()
    36  
    37  	case sqltypes.IsNumber(inarg.typeof()):
    38  		inarg.makeTextual(env.DefaultCollation)
    39  		result.setRaw(sqltypes.VarChar, inarg.bytes(), inarg.collation())
    40  
    41  	default:
    42  		coll := collations.Local().LookupByID(inarg.collation().Collation)
    43  		csa, ok := coll.(collations.CaseAwareCollation)
    44  		if !ok {
    45  			throwEvalError(vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "not implemented"))
    46  		}
    47  
    48  		dst := csa.ToLower(nil, inarg.bytes())
    49  		result.setRaw(sqltypes.VarChar, dst, inarg.collation())
    50  	}
    51  }
    52  
    53  func (builtinLower) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) {
    54  	if len(args) != 1 {
    55  		throwArgError("LOWER")
    56  	}
    57  	_, f := args[0].typeof(env)
    58  	return sqltypes.VarChar, f
    59  }
    60  
    61  type builtinLcase struct {
    62  	builtinLower
    63  }
    64  
    65  func (builtinLcase) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) {
    66  	if len(args) != 1 {
    67  		throwArgError("LCASE")
    68  	}
    69  	_, f := args[0].typeof(env)
    70  	return sqltypes.VarChar, f
    71  }
    72  
    73  type builtinUpper struct{}
    74  
    75  func (builtinUpper) call(env *ExpressionEnv, args []EvalResult, result *EvalResult) {
    76  	inarg := &args[0]
    77  
    78  	switch {
    79  	case inarg.isNull():
    80  		result.setNull()
    81  
    82  	case sqltypes.IsNumber(inarg.typeof()):
    83  		inarg.makeTextual(env.DefaultCollation)
    84  		result.setRaw(sqltypes.VarChar, inarg.bytes(), inarg.collation())
    85  
    86  	default:
    87  		coll := collations.Local().LookupByID(inarg.collation().Collation)
    88  		csa, ok := coll.(collations.CaseAwareCollation)
    89  		if !ok {
    90  			throwEvalError(vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "not implemented"))
    91  		}
    92  
    93  		dst := csa.ToUpper(nil, inarg.bytes())
    94  		result.setRaw(sqltypes.VarChar, dst, inarg.collation())
    95  	}
    96  }
    97  
    98  func (builtinUpper) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) {
    99  	if len(args) != 1 {
   100  		throwArgError("UPPER")
   101  	}
   102  	_, f := args[0].typeof(env)
   103  	return sqltypes.VarChar, f
   104  }
   105  
   106  type builtinUcase struct {
   107  	builtinUpper
   108  }
   109  
   110  func (builtinUcase) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) {
   111  	if len(args) != 1 {
   112  		throwArgError("UCASE")
   113  	}
   114  	_, f := args[0].typeof(env)
   115  	return sqltypes.VarChar, f
   116  }
   117  
   118  type builtinCharLength struct{}
   119  
   120  func (builtinCharLength) call(env *ExpressionEnv, args []EvalResult, result *EvalResult) {
   121  	inarg := &args[0]
   122  	if inarg.isNull() {
   123  		result.setNull()
   124  		return
   125  	}
   126  
   127  	coll := collations.Local().LookupByID(inarg.collation().Collation)
   128  	cnt := collations.Length(coll, inarg.toRawBytes())
   129  	result.setInt64(int64(cnt))
   130  }
   131  
   132  func (builtinCharLength) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) {
   133  	if len(args) != 1 {
   134  		throwArgError("CHAR_LENGTH")
   135  	}
   136  	_, f := args[0].typeof(env)
   137  	return sqltypes.Int64, f
   138  }
   139  
   140  type builtinCharacterLength struct {
   141  	builtinCharLength
   142  }
   143  
   144  func (builtinCharacterLength) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) {
   145  	if len(args) != 1 {
   146  		throwArgError("CHARACTER_LENGTH")
   147  	}
   148  	_, f := args[0].typeof(env)
   149  	return sqltypes.Int64, f
   150  }
   151  
   152  type builtinOctetLength struct{}
   153  
   154  func (builtinOctetLength) call(env *ExpressionEnv, args []EvalResult, result *EvalResult) {
   155  	inarg := &args[0]
   156  	if inarg.isNull() {
   157  		result.setNull()
   158  		return
   159  	}
   160  
   161  	cnt := len(inarg.toRawBytes())
   162  	result.setInt64(int64(cnt))
   163  }
   164  
   165  func (builtinOctetLength) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) {
   166  	if len(args) != 1 {
   167  		throwArgError("LENGTH")
   168  	}
   169  	_, f := args[0].typeof(env)
   170  	return sqltypes.Int64, f
   171  }
   172  
   173  type builtinLength struct {
   174  	builtinOctetLength
   175  }
   176  
   177  func (builtinLength) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) {
   178  	if len(args) != 1 {
   179  		throwArgError("OCTET_LENGTH")
   180  	}
   181  	_, f := args[0].typeof(env)
   182  	return sqltypes.Int64, f
   183  }
   184  
   185  type builtinBitLength struct {
   186  }
   187  
   188  func (builtinBitLength) call(env *ExpressionEnv, args []EvalResult, result *EvalResult) {
   189  	inarg := &args[0]
   190  	if inarg.isNull() {
   191  		result.setNull()
   192  		return
   193  	}
   194  
   195  	cnt := len(inarg.toRawBytes())
   196  	result.setInt64(int64(cnt * 8))
   197  }
   198  
   199  func (builtinBitLength) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) {
   200  	if len(args) != 1 {
   201  		throwArgError("BIT_LENGTH")
   202  	}
   203  	_, f := args[0].typeof(env)
   204  	return sqltypes.Int64, f
   205  }
   206  
   207  type builtinASCII struct {
   208  }
   209  
   210  func (builtinASCII) call(env *ExpressionEnv, args []EvalResult, result *EvalResult) {
   211  	inarg := &args[0]
   212  	if inarg.isNull() {
   213  		result.setNull()
   214  		return
   215  	}
   216  
   217  	inarg.makeBinary()
   218  	bs := inarg.bytes()
   219  	if len(bs) > 0 {
   220  		result.setInt64(int64(bs[0]))
   221  	} else {
   222  		result.setInt64(0)
   223  	}
   224  }
   225  
   226  func (builtinASCII) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) {
   227  	if len(args) != 1 {
   228  		throwArgError("ASCII")
   229  	}
   230  	_, f := args[0].typeof(env)
   231  	return sqltypes.Int64, f
   232  }
   233  
   234  type builtinRepeat struct {
   235  }
   236  
   237  func (builtinRepeat) call(env *ExpressionEnv, args []EvalResult, result *EvalResult) {
   238  	inarg := &args[0]
   239  	repeatTime := &args[1]
   240  	if inarg.isNull() || repeatTime.isNull() {
   241  		result.setNull()
   242  		return
   243  	}
   244  
   245  	if sqltypes.IsNumber(inarg.typeof()) {
   246  		inarg.makeTextual(env.DefaultCollation)
   247  	}
   248  
   249  	repeatTime.makeSignedIntegral()
   250  	repeat := int(repeatTime.int64())
   251  	if repeat < 0 {
   252  		repeat = 0
   253  	}
   254  
   255  	result.setRaw(sqltypes.VarChar, bytes.Repeat(inarg.bytes(), repeat), inarg.collation())
   256  }
   257  
   258  func (builtinRepeat) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) {
   259  	if len(args) != 2 {
   260  		throwArgError("REPEAT")
   261  	}
   262  	_, f1 := args[0].typeof(env)
   263  	// typecheck the right-hand argument but ignore its flags
   264  	args[1].typeof(env)
   265  
   266  	return sqltypes.VarChar, f1
   267  }