storj.io/minio@v0.0.0-20230509071714-0cbc90f649b1/pkg/s3select/sql/funceval.go (about)

     1  /*
     2   * MinIO Cloud Storage, (C) 2019 MinIO, Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package sql
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"strconv"
    23  	"strings"
    24  	"time"
    25  )
    26  
// FuncName - SQL function name.
//
// It is a plain string type. The constants in this file hold the names in
// upper case, and getFunctionName upper-cases the name taken from a parsed
// simple-function node before converting it to a FuncName.
type FuncName string
    29  
// SQL Function name constants.
//
// All names are upper case; getFunctionName upper-cases the name of a simple
// function call before it is compared against these values.
const (
	// Conditionals
	sqlFnCoalesce FuncName = "COALESCE"
	sqlFnNullIf   FuncName = "NULLIF"

	// Conversion
	sqlFnCast FuncName = "CAST"

	// Date and time. TO_STRING and TO_TIMESTAMP are declared here, but
	// evalSQLFnNode currently returns errNotImplemented for both.
	sqlFnDateAdd     FuncName = "DATE_ADD"
	sqlFnDateDiff    FuncName = "DATE_DIFF"
	sqlFnExtract     FuncName = "EXTRACT"
	sqlFnToString    FuncName = "TO_STRING"
	sqlFnToTimestamp FuncName = "TO_TIMESTAMP"
	sqlFnUTCNow      FuncName = "UTCNOW"

	// String. CHAR_LENGTH and CHARACTER_LENGTH are handled identically by
	// evalSQLFnNode.
	sqlFnCharLength      FuncName = "CHAR_LENGTH"
	sqlFnCharacterLength FuncName = "CHARACTER_LENGTH"
	sqlFnLower           FuncName = "LOWER"
	sqlFnSubstring       FuncName = "SUBSTRING"
	sqlFnTrim            FuncName = "TRIM"
	sqlFnUpper           FuncName = "UPPER"
)
    55  
// Sentinel errors used by the function evaluators in this file:
//
//   - errUnimplementedCast: returned by castTo for a cast type it does not
//     handle.
//   - errNonStringTrimArg: returned by handleSQLTrim when either TRIM argument
//     cannot be read as a string.
//   - errNonTimestampArg: returned by handleSQLExtract when the FROM value
//     cannot be read as a timestamp.
var (
	errUnimplementedCast = errors.New("This cast not yet implemented")
	errNonStringTrimArg  = errors.New("TRIM() received a non-string argument")
	errNonTimestampArg   = errors.New("Expected a timestamp argument")
)
    61  
    62  func (e *FuncExpr) getFunctionName() FuncName {
    63  	switch {
    64  	case e.SFunc != nil:
    65  		return FuncName(strings.ToUpper(e.SFunc.FunctionName))
    66  	case e.Count != nil:
    67  		return FuncName(aggFnCount)
    68  	case e.Cast != nil:
    69  		return sqlFnCast
    70  	case e.Substring != nil:
    71  		return sqlFnSubstring
    72  	case e.Extract != nil:
    73  		return sqlFnExtract
    74  	case e.Trim != nil:
    75  		return sqlFnTrim
    76  	case e.DateAdd != nil:
    77  		return sqlFnDateAdd
    78  	case e.DateDiff != nil:
    79  		return sqlFnDateDiff
    80  	default:
    81  		return ""
    82  	}
    83  }
    84  
    85  // evalSQLFnNode assumes that the FuncExpr is not an aggregation
    86  // function.
    87  func (e *FuncExpr) evalSQLFnNode(r Record, tableAlias string) (res *Value, err error) {
    88  	// Handle functions that have phrase arguments
    89  	switch e.getFunctionName() {
    90  	case sqlFnCast:
    91  		expr := e.Cast.Expr
    92  		res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType), tableAlias)
    93  		return
    94  
    95  	case sqlFnSubstring:
    96  		return handleSQLSubstring(r, e.Substring, tableAlias)
    97  
    98  	case sqlFnExtract:
    99  		return handleSQLExtract(r, e.Extract, tableAlias)
   100  
   101  	case sqlFnTrim:
   102  		return handleSQLTrim(r, e.Trim, tableAlias)
   103  
   104  	case sqlFnDateAdd:
   105  		return handleDateAdd(r, e.DateAdd, tableAlias)
   106  
   107  	case sqlFnDateDiff:
   108  		return handleDateDiff(r, e.DateDiff, tableAlias)
   109  
   110  	}
   111  
   112  	// For all simple argument functions, we evaluate the arguments here
   113  	argVals := make([]*Value, len(e.SFunc.ArgsList))
   114  	for i, arg := range e.SFunc.ArgsList {
   115  		argVals[i], err = arg.evalNode(r, tableAlias)
   116  		if err != nil {
   117  			return nil, err
   118  		}
   119  	}
   120  
   121  	switch e.getFunctionName() {
   122  	case sqlFnCoalesce:
   123  		return coalesce(argVals)
   124  
   125  	case sqlFnNullIf:
   126  		return nullif(argVals[0], argVals[1])
   127  
   128  	case sqlFnCharLength, sqlFnCharacterLength:
   129  		return charlen(argVals[0])
   130  
   131  	case sqlFnLower:
   132  		return lowerCase(argVals[0])
   133  
   134  	case sqlFnUpper:
   135  		return upperCase(argVals[0])
   136  
   137  	case sqlFnUTCNow:
   138  		return handleUTCNow()
   139  
   140  	case sqlFnToString, sqlFnToTimestamp:
   141  		// TODO: implement
   142  		fallthrough
   143  
   144  	default:
   145  		return nil, errNotImplemented
   146  	}
   147  }
   148  
   149  func coalesce(args []*Value) (res *Value, err error) {
   150  	for _, arg := range args {
   151  		if arg.IsNull() {
   152  			continue
   153  		}
   154  		return arg, nil
   155  	}
   156  	return FromNull(), nil
   157  }
   158  
   159  func nullif(v1, v2 *Value) (res *Value, err error) {
   160  	// Handle Null cases
   161  	if v1.IsNull() || v2.IsNull() {
   162  		return v1, nil
   163  	}
   164  
   165  	err = inferTypesForCmp(v1, v2)
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  
   170  	atleastOneNumeric := v1.isNumeric() || v2.isNumeric()
   171  	bothNumeric := v1.isNumeric() && v2.isNumeric()
   172  	if atleastOneNumeric || !bothNumeric {
   173  		return v1, nil
   174  	}
   175  
   176  	if v1.SameTypeAs(*v2) {
   177  		return v1, nil
   178  	}
   179  
   180  	cmpResult, cmpErr := v1.compareOp(opEq, v2)
   181  	if cmpErr != nil {
   182  		return nil, cmpErr
   183  	}
   184  
   185  	if cmpResult {
   186  		return FromNull(), nil
   187  	}
   188  
   189  	return v1, nil
   190  }
   191  
   192  func charlen(v *Value) (*Value, error) {
   193  	inferTypeAsString(v)
   194  	s, ok := v.ToString()
   195  	if !ok {
   196  		err := fmt.Errorf("%s/%s expects a string argument", sqlFnCharLength, sqlFnCharacterLength)
   197  		return nil, errIncorrectSQLFunctionArgumentType(err)
   198  	}
   199  	return FromInt(int64(len([]rune(s)))), nil
   200  }
   201  
   202  func lowerCase(v *Value) (*Value, error) {
   203  	inferTypeAsString(v)
   204  	s, ok := v.ToString()
   205  	if !ok {
   206  		err := fmt.Errorf("%s expects a string argument", sqlFnLower)
   207  		return nil, errIncorrectSQLFunctionArgumentType(err)
   208  	}
   209  	return FromString(strings.ToLower(s)), nil
   210  }
   211  
   212  func upperCase(v *Value) (*Value, error) {
   213  	inferTypeAsString(v)
   214  	s, ok := v.ToString()
   215  	if !ok {
   216  		err := fmt.Errorf("%s expects a string argument", sqlFnUpper)
   217  		return nil, errIncorrectSQLFunctionArgumentType(err)
   218  	}
   219  	return FromString(strings.ToUpper(s)), nil
   220  }
   221  
   222  func handleDateAdd(r Record, d *DateAddFunc, tableAlias string) (*Value, error) {
   223  	q, err := d.Quantity.evalNode(r, tableAlias)
   224  	if err != nil {
   225  		return nil, err
   226  	}
   227  	inferTypeForArithOp(q)
   228  	qty, ok := q.ToFloat()
   229  	if !ok {
   230  		return nil, fmt.Errorf("QUANTITY must be a numeric argument to %s()", sqlFnDateAdd)
   231  	}
   232  
   233  	ts, err := d.Timestamp.evalNode(r, tableAlias)
   234  	if err != nil {
   235  		return nil, err
   236  	}
   237  	if err = inferTypeAsTimestamp(ts); err != nil {
   238  		return nil, err
   239  	}
   240  	t, ok := ts.ToTimestamp()
   241  	if !ok {
   242  		return nil, fmt.Errorf("%s() expects a timestamp argument", sqlFnDateAdd)
   243  	}
   244  
   245  	return dateAdd(strings.ToUpper(d.DatePart), qty, t)
   246  }
   247  
   248  func handleDateDiff(r Record, d *DateDiffFunc, tableAlias string) (*Value, error) {
   249  	tval1, err := d.Timestamp1.evalNode(r, tableAlias)
   250  	if err != nil {
   251  		return nil, err
   252  	}
   253  	if err = inferTypeAsTimestamp(tval1); err != nil {
   254  		return nil, err
   255  	}
   256  	ts1, ok := tval1.ToTimestamp()
   257  	if !ok {
   258  		return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff)
   259  	}
   260  
   261  	tval2, err := d.Timestamp2.evalNode(r, tableAlias)
   262  	if err != nil {
   263  		return nil, err
   264  	}
   265  	if err = inferTypeAsTimestamp(tval2); err != nil {
   266  		return nil, err
   267  	}
   268  	ts2, ok := tval2.ToTimestamp()
   269  	if !ok {
   270  		return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff)
   271  	}
   272  
   273  	return dateDiff(strings.ToUpper(d.DatePart), ts1, ts2)
   274  }
   275  
   276  func handleUTCNow() (*Value, error) {
   277  	return FromTimestamp(time.Now().UTC()), nil
   278  }
   279  
   280  func handleSQLSubstring(r Record, e *SubstringFunc, tableAlias string) (val *Value, err error) {
   281  	// Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and
   282  	// SUBSTRING('abc', 2, 1) are supported.
   283  
   284  	// Evaluate the string argument
   285  	v1, err := e.Expr.evalNode(r, tableAlias)
   286  	if err != nil {
   287  		return nil, err
   288  	}
   289  	inferTypeAsString(v1)
   290  	s, ok := v1.ToString()
   291  	if !ok {
   292  		err := fmt.Errorf("Incorrect argument type passed to %s", sqlFnSubstring)
   293  		return nil, errIncorrectSQLFunctionArgumentType(err)
   294  	}
   295  
   296  	// Assemble other arguments
   297  	arg2, arg3 := e.From, e.For
   298  	// Check if the second form of substring is being used
   299  	if e.From == nil {
   300  		arg2, arg3 = e.Arg2, e.Arg3
   301  	}
   302  
   303  	// Evaluate the FROM argument
   304  	v2, err := arg2.evalNode(r, tableAlias)
   305  	if err != nil {
   306  		return nil, err
   307  	}
   308  	inferTypeForArithOp(v2)
   309  	startIdx, ok := v2.ToInt()
   310  	if !ok {
   311  		err := fmt.Errorf("Incorrect type for start index argument in %s", sqlFnSubstring)
   312  		return nil, errIncorrectSQLFunctionArgumentType(err)
   313  	}
   314  
   315  	length := -1
   316  	// Evaluate the optional FOR argument
   317  	if arg3 != nil {
   318  		v3, err := arg3.evalNode(r, tableAlias)
   319  		if err != nil {
   320  			return nil, err
   321  		}
   322  		inferTypeForArithOp(v3)
   323  		lenInt, ok := v3.ToInt()
   324  		if !ok {
   325  			err := fmt.Errorf("Incorrect type for length argument in %s", sqlFnSubstring)
   326  			return nil, errIncorrectSQLFunctionArgumentType(err)
   327  		}
   328  		length = int(lenInt)
   329  		if length < 0 {
   330  			err := fmt.Errorf("Negative length argument in %s", sqlFnSubstring)
   331  			return nil, errIncorrectSQLFunctionArgumentType(err)
   332  		}
   333  	}
   334  
   335  	res, err := evalSQLSubstring(s, int(startIdx), length)
   336  	return FromString(res), err
   337  }
   338  
   339  func handleSQLTrim(r Record, e *TrimFunc, tableAlias string) (res *Value, err error) {
   340  	chars := ""
   341  	ok := false
   342  	if e.TrimChars != nil {
   343  		charsV, cerr := e.TrimChars.evalNode(r, tableAlias)
   344  		if cerr != nil {
   345  			return nil, cerr
   346  		}
   347  		inferTypeAsString(charsV)
   348  		chars, ok = charsV.ToString()
   349  		if !ok {
   350  			return nil, errNonStringTrimArg
   351  		}
   352  	}
   353  
   354  	fromV, ferr := e.TrimFrom.evalNode(r, tableAlias)
   355  	if ferr != nil {
   356  		return nil, ferr
   357  	}
   358  	inferTypeAsString(fromV)
   359  	from, ok := fromV.ToString()
   360  	if !ok {
   361  		return nil, errNonStringTrimArg
   362  	}
   363  
   364  	result, terr := evalSQLTrim(e.TrimWhere, chars, from)
   365  	if terr != nil {
   366  		return nil, terr
   367  	}
   368  	return FromString(result), nil
   369  }
   370  
   371  func handleSQLExtract(r Record, e *ExtractFunc, tableAlias string) (res *Value, err error) {
   372  	timeVal, verr := e.From.evalNode(r, tableAlias)
   373  	if verr != nil {
   374  		return nil, verr
   375  	}
   376  
   377  	if err = inferTypeAsTimestamp(timeVal); err != nil {
   378  		return nil, err
   379  	}
   380  
   381  	t, ok := timeVal.ToTimestamp()
   382  	if !ok {
   383  		return nil, errNonTimestampArg
   384  	}
   385  
   386  	return extract(strings.ToUpper(e.Timeword), t)
   387  }
   388  
   389  func errUnsupportedCast(fromType, toType string) error {
   390  	return fmt.Errorf("Cannot cast from %v to %v", fromType, toType)
   391  }
   392  
   393  func errCastFailure(msg string) error {
   394  	return fmt.Errorf("Error casting: %s", msg)
   395  }
   396  
// Allowed cast types.
//
// These are the upper-case type names castTo switches on. DECIMAL and NUMERIC
// are listed but castTo currently answers them with errUnimplementedCast.
const (
	castBool      = "BOOL"
	castInt       = "INT"
	castInteger   = "INTEGER"
	castString    = "STRING"
	castFloat     = "FLOAT"
	castDecimal   = "DECIMAL"
	castNumeric   = "NUMERIC"
	castTimestamp = "TIMESTAMP"
)
   408  
   409  func (e *Expression) castTo(r Record, castType string, tableAlias string) (res *Value, err error) {
   410  	v, err := e.evalNode(r, tableAlias)
   411  	if err != nil {
   412  		return nil, err
   413  	}
   414  
   415  	switch castType {
   416  	case castInt, castInteger:
   417  		i, err := intCast(v)
   418  		return FromInt(i), err
   419  
   420  	case castFloat:
   421  		f, err := floatCast(v)
   422  		return FromFloat(f), err
   423  
   424  	case castString:
   425  		s, err := stringCast(v)
   426  		return FromString(s), err
   427  
   428  	case castTimestamp:
   429  		t, err := timestampCast(v)
   430  		return FromTimestamp(t), err
   431  
   432  	case castBool:
   433  		b, err := boolCast(v)
   434  		return FromBool(b), err
   435  
   436  	case castDecimal, castNumeric:
   437  		fallthrough
   438  
   439  	default:
   440  		return nil, errUnimplementedCast
   441  	}
   442  }
   443  
   444  func intCast(v *Value) (int64, error) {
   445  	// This conversion truncates floating point numbers to
   446  	// integer.
   447  	strToInt := func(s string) (int64, bool) {
   448  		i, errI := strconv.ParseInt(s, 10, 64)
   449  		if errI == nil {
   450  			return i, true
   451  		}
   452  		f, errF := strconv.ParseFloat(s, 64)
   453  		if errF == nil {
   454  			return int64(f), true
   455  		}
   456  		return 0, false
   457  	}
   458  
   459  	switch x := v.value.(type) {
   460  	case float64:
   461  		// Truncate fractional part
   462  		return int64(x), nil
   463  	case int64:
   464  		return x, nil
   465  	case string:
   466  		// Parse as number, truncate floating point if
   467  		// needed.
   468  		// String might contain trimming spaces, which
   469  		// needs to be trimmed.
   470  		res, ok := strToInt(strings.TrimSpace(x))
   471  		if !ok {
   472  			return 0, errCastFailure("could not parse as int")
   473  		}
   474  		return res, nil
   475  	case []byte:
   476  		// Parse as number, truncate floating point if
   477  		// needed.
   478  		// String might contain trimming spaces, which
   479  		// needs to be trimmed.
   480  		res, ok := strToInt(strings.TrimSpace(string(x)))
   481  		if !ok {
   482  			return 0, errCastFailure("could not parse as int")
   483  		}
   484  		return res, nil
   485  
   486  	default:
   487  		return 0, errUnsupportedCast(v.GetTypeString(), castInt)
   488  	}
   489  }
   490  
   491  func floatCast(v *Value) (float64, error) {
   492  	switch x := v.value.(type) {
   493  	case float64:
   494  		return x, nil
   495  	case int64:
   496  		return float64(x), nil
   497  	case string:
   498  		f, err := strconv.ParseFloat(strings.TrimSpace(x), 64)
   499  		if err != nil {
   500  			return 0, errCastFailure("could not parse as float")
   501  		}
   502  		return f, nil
   503  	case []byte:
   504  		f, err := strconv.ParseFloat(strings.TrimSpace(string(x)), 64)
   505  		if err != nil {
   506  			return 0, errCastFailure("could not parse as float")
   507  		}
   508  		return f, nil
   509  	default:
   510  		return 0, errUnsupportedCast(v.GetTypeString(), castFloat)
   511  	}
   512  }
   513  
   514  func stringCast(v *Value) (string, error) {
   515  	switch x := v.value.(type) {
   516  	case float64:
   517  		return fmt.Sprintf("%v", x), nil
   518  	case int64:
   519  		return fmt.Sprintf("%v", x), nil
   520  	case string:
   521  		return x, nil
   522  	case []byte:
   523  		return string(x), nil
   524  	case bool:
   525  		return fmt.Sprintf("%v", x), nil
   526  	case nil:
   527  		// FIXME: verify this case is correct
   528  		return "NULL", nil
   529  	}
   530  	// This does not happen
   531  	return "", errCastFailure(fmt.Sprintf("cannot cast %v to string type", v.GetTypeString()))
   532  }
   533  
   534  func timestampCast(v *Value) (t time.Time, _ error) {
   535  	switch x := v.value.(type) {
   536  	case string:
   537  		return parseSQLTimestamp(x)
   538  	case []byte:
   539  		return parseSQLTimestamp(string(x))
   540  	case time.Time:
   541  		return x, nil
   542  	default:
   543  		return t, errCastFailure(fmt.Sprintf("cannot cast %v to Timestamp type", v.GetTypeString()))
   544  	}
   545  }
   546  
   547  func boolCast(v *Value) (b bool, _ error) {
   548  	sToB := func(s string) (bool, error) {
   549  		switch s {
   550  		case "true":
   551  			return true, nil
   552  		case "false":
   553  			return false, nil
   554  		default:
   555  			return false, errCastFailure("cannot cast to Bool")
   556  		}
   557  	}
   558  	switch x := v.value.(type) {
   559  	case bool:
   560  		return x, nil
   561  	case string:
   562  		return sToB(strings.ToLower(x))
   563  	case []byte:
   564  		return sToB(strings.ToLower(string(x)))
   565  	default:
   566  		return false, errCastFailure("cannot cast %v to Bool")
   567  	}
   568  }