github.com/minio/minio@v0.0.0-20240328213742-3f72439b8a27/internal/s3select/sql/funceval.go (about)

     1  // Copyright (c) 2015-2021 MinIO, Inc.
     2  //
     3  // This file is part of MinIO Object Storage stack
     4  //
     5  // This program is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Affero General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // This program is distributed in the hope that it will be useful
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13  // GNU Affero General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Affero General Public License
    16  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17  
    18  package sql
    19  
    20  import (
    21  	"errors"
    22  	"fmt"
    23  	"strconv"
    24  	"strings"
    25  	"time"
    26  )
    27  
// FuncName - SQL function name.
//
// Values of this type are compared against the function-name constants
// declared in this file. Names taken from a parsed simple function call
// are upper-cased before being converted (see getFunctionName).
type FuncName string
    30  
// SQL Function name constants
//
// Every value is spelled in upper case because getFunctionName upper-cases
// the name of a simple function call before comparing it with these
// constants.
const (
	// Conditionals
	sqlFnCoalesce FuncName = "COALESCE"
	sqlFnNullIf   FuncName = "NULLIF"

	// Conversion
	sqlFnCast FuncName = "CAST"

	// Date and time
	//
	// TO_STRING and TO_TIMESTAMP are declared here but evalSQLFnNode
	// currently answers them with errNotImplemented.
	sqlFnDateAdd     FuncName = "DATE_ADD"
	sqlFnDateDiff    FuncName = "DATE_DIFF"
	sqlFnExtract     FuncName = "EXTRACT"
	sqlFnToString    FuncName = "TO_STRING"
	sqlFnToTimestamp FuncName = "TO_TIMESTAMP"
	sqlFnUTCNow      FuncName = "UTCNOW"

	// String
	//
	// CHAR_LENGTH and CHARACTER_LENGTH are handled by the same
	// implementation (charlen).
	sqlFnCharLength      FuncName = "CHAR_LENGTH"
	sqlFnCharacterLength FuncName = "CHARACTER_LENGTH"
	sqlFnLower           FuncName = "LOWER"
	sqlFnSubstring       FuncName = "SUBSTRING"
	sqlFnTrim            FuncName = "TRIM"
	sqlFnUpper           FuncName = "UPPER"
)
    56  
// Sentinel errors returned by the function evaluators in this file.
var (
	// errUnimplementedCast is returned by castTo for cast targets it does
	// not handle (including DECIMAL and NUMERIC).
	errUnimplementedCast = errors.New("This cast not yet implemented")
	// errNonStringTrimArg is returned by handleSQLTrim when either TRIM
	// operand cannot be read as a string.
	errNonStringTrimArg = errors.New("TRIM() received a non-string argument")
	// errNonTimestampArg is returned by handleSQLExtract when the EXTRACT
	// source cannot be read as a timestamp.
	errNonTimestampArg = errors.New("Expected a timestamp argument")
)
    62  
    63  func (e *FuncExpr) getFunctionName() FuncName {
    64  	switch {
    65  	case e.SFunc != nil:
    66  		return FuncName(strings.ToUpper(e.SFunc.FunctionName))
    67  	case e.Count != nil:
    68  		return aggFnCount
    69  	case e.Cast != nil:
    70  		return sqlFnCast
    71  	case e.Substring != nil:
    72  		return sqlFnSubstring
    73  	case e.Extract != nil:
    74  		return sqlFnExtract
    75  	case e.Trim != nil:
    76  		return sqlFnTrim
    77  	case e.DateAdd != nil:
    78  		return sqlFnDateAdd
    79  	case e.DateDiff != nil:
    80  		return sqlFnDateDiff
    81  	default:
    82  		return ""
    83  	}
    84  }
    85  
    86  // evalSQLFnNode assumes that the FuncExpr is not an aggregation
    87  // function.
    88  func (e *FuncExpr) evalSQLFnNode(r Record, tableAlias string) (res *Value, err error) {
    89  	// Handle functions that have phrase arguments
    90  	switch e.getFunctionName() {
    91  	case sqlFnCast:
    92  		expr := e.Cast.Expr
    93  		res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType), tableAlias)
    94  		return
    95  
    96  	case sqlFnSubstring:
    97  		return handleSQLSubstring(r, e.Substring, tableAlias)
    98  
    99  	case sqlFnExtract:
   100  		return handleSQLExtract(r, e.Extract, tableAlias)
   101  
   102  	case sqlFnTrim:
   103  		return handleSQLTrim(r, e.Trim, tableAlias)
   104  
   105  	case sqlFnDateAdd:
   106  		return handleDateAdd(r, e.DateAdd, tableAlias)
   107  
   108  	case sqlFnDateDiff:
   109  		return handleDateDiff(r, e.DateDiff, tableAlias)
   110  
   111  	}
   112  
   113  	// For all simple argument functions, we evaluate the arguments here
   114  	argVals := make([]*Value, len(e.SFunc.ArgsList))
   115  	for i, arg := range e.SFunc.ArgsList {
   116  		argVals[i], err = arg.evalNode(r, tableAlias)
   117  		if err != nil {
   118  			return nil, err
   119  		}
   120  	}
   121  
   122  	switch e.getFunctionName() {
   123  	case sqlFnCoalesce:
   124  		return coalesce(argVals)
   125  
   126  	case sqlFnNullIf:
   127  		return nullif(argVals[0], argVals[1])
   128  
   129  	case sqlFnCharLength, sqlFnCharacterLength:
   130  		return charlen(argVals[0])
   131  
   132  	case sqlFnLower:
   133  		return lowerCase(argVals[0])
   134  
   135  	case sqlFnUpper:
   136  		return upperCase(argVals[0])
   137  
   138  	case sqlFnUTCNow:
   139  		return handleUTCNow()
   140  
   141  	case sqlFnToString, sqlFnToTimestamp:
   142  		// TODO: implement
   143  		fallthrough
   144  
   145  	default:
   146  		return nil, errNotImplemented
   147  	}
   148  }
   149  
   150  func coalesce(args []*Value) (res *Value, err error) {
   151  	for _, arg := range args {
   152  		if arg.IsNull() {
   153  			continue
   154  		}
   155  		return arg, nil
   156  	}
   157  	return FromNull(), nil
   158  }
   159  
   160  func nullif(v1, v2 *Value) (res *Value, err error) {
   161  	// Handle Null cases
   162  	if v1.IsNull() || v2.IsNull() {
   163  		return v1, nil
   164  	}
   165  
   166  	err = inferTypesForCmp(v1, v2)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	atleastOneNumeric := v1.isNumeric() || v2.isNumeric()
   172  	bothNumeric := v1.isNumeric() && v2.isNumeric()
   173  	if atleastOneNumeric || !bothNumeric {
   174  		return v1, nil
   175  	}
   176  
   177  	if v1.SameTypeAs(*v2) {
   178  		return v1, nil
   179  	}
   180  
   181  	cmpResult, cmpErr := v1.compareOp(opEq, v2)
   182  	if cmpErr != nil {
   183  		return nil, cmpErr
   184  	}
   185  
   186  	if cmpResult {
   187  		return FromNull(), nil
   188  	}
   189  
   190  	return v1, nil
   191  }
   192  
   193  func charlen(v *Value) (*Value, error) {
   194  	inferTypeAsString(v)
   195  	s, ok := v.ToString()
   196  	if !ok {
   197  		err := fmt.Errorf("%s/%s expects a string argument", sqlFnCharLength, sqlFnCharacterLength)
   198  		return nil, errIncorrectSQLFunctionArgumentType(err)
   199  	}
   200  	return FromInt(int64(len([]rune(s)))), nil
   201  }
   202  
   203  func lowerCase(v *Value) (*Value, error) {
   204  	inferTypeAsString(v)
   205  	s, ok := v.ToString()
   206  	if !ok {
   207  		err := fmt.Errorf("%s expects a string argument", sqlFnLower)
   208  		return nil, errIncorrectSQLFunctionArgumentType(err)
   209  	}
   210  	return FromString(strings.ToLower(s)), nil
   211  }
   212  
   213  func upperCase(v *Value) (*Value, error) {
   214  	inferTypeAsString(v)
   215  	s, ok := v.ToString()
   216  	if !ok {
   217  		err := fmt.Errorf("%s expects a string argument", sqlFnUpper)
   218  		return nil, errIncorrectSQLFunctionArgumentType(err)
   219  	}
   220  	return FromString(strings.ToUpper(s)), nil
   221  }
   222  
   223  func handleDateAdd(r Record, d *DateAddFunc, tableAlias string) (*Value, error) {
   224  	q, err := d.Quantity.evalNode(r, tableAlias)
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  	inferTypeForArithOp(q)
   229  	qty, ok := q.ToFloat()
   230  	if !ok {
   231  		return nil, fmt.Errorf("QUANTITY must be a numeric argument to %s()", sqlFnDateAdd)
   232  	}
   233  
   234  	ts, err := d.Timestamp.evalNode(r, tableAlias)
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  	if err = inferTypeAsTimestamp(ts); err != nil {
   239  		return nil, err
   240  	}
   241  	t, ok := ts.ToTimestamp()
   242  	if !ok {
   243  		return nil, fmt.Errorf("%s() expects a timestamp argument", sqlFnDateAdd)
   244  	}
   245  
   246  	return dateAdd(strings.ToUpper(d.DatePart), qty, t)
   247  }
   248  
   249  func handleDateDiff(r Record, d *DateDiffFunc, tableAlias string) (*Value, error) {
   250  	tval1, err := d.Timestamp1.evalNode(r, tableAlias)
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  	if err = inferTypeAsTimestamp(tval1); err != nil {
   255  		return nil, err
   256  	}
   257  	ts1, ok := tval1.ToTimestamp()
   258  	if !ok {
   259  		return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff)
   260  	}
   261  
   262  	tval2, err := d.Timestamp2.evalNode(r, tableAlias)
   263  	if err != nil {
   264  		return nil, err
   265  	}
   266  	if err = inferTypeAsTimestamp(tval2); err != nil {
   267  		return nil, err
   268  	}
   269  	ts2, ok := tval2.ToTimestamp()
   270  	if !ok {
   271  		return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff)
   272  	}
   273  
   274  	return dateDiff(strings.ToUpper(d.DatePart), ts1, ts2)
   275  }
   276  
   277  func handleUTCNow() (*Value, error) {
   278  	return FromTimestamp(time.Now().UTC()), nil
   279  }
   280  
   281  func handleSQLSubstring(r Record, e *SubstringFunc, tableAlias string) (val *Value, err error) {
   282  	// Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and
   283  	// SUBSTRING('abc', 2, 1) are supported.
   284  
   285  	// Evaluate the string argument
   286  	v1, err := e.Expr.evalNode(r, tableAlias)
   287  	if err != nil {
   288  		return nil, err
   289  	}
   290  	inferTypeAsString(v1)
   291  	s, ok := v1.ToString()
   292  	if !ok {
   293  		err := fmt.Errorf("Incorrect argument type passed to %s", sqlFnSubstring)
   294  		return nil, errIncorrectSQLFunctionArgumentType(err)
   295  	}
   296  
   297  	// Assemble other arguments
   298  	arg2, arg3 := e.From, e.For
   299  	// Check if the second form of substring is being used
   300  	if e.From == nil {
   301  		arg2, arg3 = e.Arg2, e.Arg3
   302  	}
   303  
   304  	// Evaluate the FROM argument
   305  	v2, err := arg2.evalNode(r, tableAlias)
   306  	if err != nil {
   307  		return nil, err
   308  	}
   309  	inferTypeForArithOp(v2)
   310  	startIdx, ok := v2.ToInt()
   311  	if !ok {
   312  		err := fmt.Errorf("Incorrect type for start index argument in %s", sqlFnSubstring)
   313  		return nil, errIncorrectSQLFunctionArgumentType(err)
   314  	}
   315  
   316  	length := -1
   317  	// Evaluate the optional FOR argument
   318  	if arg3 != nil {
   319  		v3, err := arg3.evalNode(r, tableAlias)
   320  		if err != nil {
   321  			return nil, err
   322  		}
   323  		inferTypeForArithOp(v3)
   324  		lenInt, ok := v3.ToInt()
   325  		if !ok {
   326  			err := fmt.Errorf("Incorrect type for length argument in %s", sqlFnSubstring)
   327  			return nil, errIncorrectSQLFunctionArgumentType(err)
   328  		}
   329  		length = int(lenInt)
   330  		if length < 0 {
   331  			err := fmt.Errorf("Negative length argument in %s", sqlFnSubstring)
   332  			return nil, errIncorrectSQLFunctionArgumentType(err)
   333  		}
   334  	}
   335  
   336  	res, err := evalSQLSubstring(s, int(startIdx), length)
   337  	return FromString(res), err
   338  }
   339  
   340  func handleSQLTrim(r Record, e *TrimFunc, tableAlias string) (res *Value, err error) {
   341  	chars := ""
   342  	ok := false
   343  	if e.TrimChars != nil {
   344  		charsV, cerr := e.TrimChars.evalNode(r, tableAlias)
   345  		if cerr != nil {
   346  			return nil, cerr
   347  		}
   348  		inferTypeAsString(charsV)
   349  		chars, ok = charsV.ToString()
   350  		if !ok {
   351  			return nil, errNonStringTrimArg
   352  		}
   353  	}
   354  
   355  	fromV, ferr := e.TrimFrom.evalNode(r, tableAlias)
   356  	if ferr != nil {
   357  		return nil, ferr
   358  	}
   359  	inferTypeAsString(fromV)
   360  	from, ok := fromV.ToString()
   361  	if !ok {
   362  		return nil, errNonStringTrimArg
   363  	}
   364  
   365  	result, terr := evalSQLTrim(e.TrimWhere, chars, from)
   366  	if terr != nil {
   367  		return nil, terr
   368  	}
   369  	return FromString(result), nil
   370  }
   371  
   372  func handleSQLExtract(r Record, e *ExtractFunc, tableAlias string) (res *Value, err error) {
   373  	timeVal, verr := e.From.evalNode(r, tableAlias)
   374  	if verr != nil {
   375  		return nil, verr
   376  	}
   377  
   378  	if err = inferTypeAsTimestamp(timeVal); err != nil {
   379  		return nil, err
   380  	}
   381  
   382  	t, ok := timeVal.ToTimestamp()
   383  	if !ok {
   384  		return nil, errNonTimestampArg
   385  	}
   386  
   387  	return extract(strings.ToUpper(e.Timeword), t)
   388  }
   389  
   390  func errUnsupportedCast(fromType, toType string) error {
   391  	return fmt.Errorf("Cannot cast from %v to %v", fromType, toType)
   392  }
   393  
   394  func errCastFailure(msg string) error {
   395  	return fmt.Errorf("Error casting: %s", msg)
   396  }
   397  
// Allowed cast types
//
// castTo receives the target type already upper-cased by its caller and
// matches it against these names. INT/INTEGER share one implementation.
// DECIMAL and NUMERIC are recognized names but castTo currently rejects
// them with errUnimplementedCast.
const (
	castBool      = "BOOL"
	castInt       = "INT"
	castInteger   = "INTEGER"
	castString    = "STRING"
	castFloat     = "FLOAT"
	castDecimal   = "DECIMAL"
	castNumeric   = "NUMERIC"
	castTimestamp = "TIMESTAMP"
)
   409  
   410  func (e *Expression) castTo(r Record, castType string, tableAlias string) (res *Value, err error) {
   411  	v, err := e.evalNode(r, tableAlias)
   412  	if err != nil {
   413  		return nil, err
   414  	}
   415  
   416  	switch castType {
   417  	case castInt, castInteger:
   418  		i, err := intCast(v)
   419  		return FromInt(i), err
   420  
   421  	case castFloat:
   422  		f, err := floatCast(v)
   423  		return FromFloat(f), err
   424  
   425  	case castString:
   426  		s, err := stringCast(v)
   427  		return FromString(s), err
   428  
   429  	case castTimestamp:
   430  		t, err := timestampCast(v)
   431  		return FromTimestamp(t), err
   432  
   433  	case castBool:
   434  		b, err := boolCast(v)
   435  		return FromBool(b), err
   436  
   437  	case castDecimal, castNumeric:
   438  		fallthrough
   439  
   440  	default:
   441  		return nil, errUnimplementedCast
   442  	}
   443  }
   444  
   445  func intCast(v *Value) (int64, error) {
   446  	// This conversion truncates floating point numbers to
   447  	// integer.
   448  	strToInt := func(s string) (int64, bool) {
   449  		i, errI := strconv.ParseInt(s, 10, 64)
   450  		if errI == nil {
   451  			return i, true
   452  		}
   453  		f, errF := strconv.ParseFloat(s, 64)
   454  		if errF == nil {
   455  			return int64(f), true
   456  		}
   457  		return 0, false
   458  	}
   459  
   460  	switch x := v.value.(type) {
   461  	case float64:
   462  		// Truncate fractional part
   463  		return int64(x), nil
   464  	case int64:
   465  		return x, nil
   466  	case string:
   467  		// Parse as number, truncate floating point if
   468  		// needed.
   469  		// String might contain trimming spaces, which
   470  		// needs to be trimmed.
   471  		res, ok := strToInt(strings.TrimSpace(x))
   472  		if !ok {
   473  			return 0, errCastFailure("could not parse as int")
   474  		}
   475  		return res, nil
   476  	case []byte:
   477  		// Parse as number, truncate floating point if
   478  		// needed.
   479  		// String might contain trimming spaces, which
   480  		// needs to be trimmed.
   481  		res, ok := strToInt(strings.TrimSpace(string(x)))
   482  		if !ok {
   483  			return 0, errCastFailure("could not parse as int")
   484  		}
   485  		return res, nil
   486  
   487  	default:
   488  		return 0, errUnsupportedCast(v.GetTypeString(), castInt)
   489  	}
   490  }
   491  
   492  func floatCast(v *Value) (float64, error) {
   493  	switch x := v.value.(type) {
   494  	case float64:
   495  		return x, nil
   496  	case int64:
   497  		return float64(x), nil
   498  	case string:
   499  		f, err := strconv.ParseFloat(strings.TrimSpace(x), 64)
   500  		if err != nil {
   501  			return 0, errCastFailure("could not parse as float")
   502  		}
   503  		return f, nil
   504  	case []byte:
   505  		f, err := strconv.ParseFloat(strings.TrimSpace(string(x)), 64)
   506  		if err != nil {
   507  			return 0, errCastFailure("could not parse as float")
   508  		}
   509  		return f, nil
   510  	default:
   511  		return 0, errUnsupportedCast(v.GetTypeString(), castFloat)
   512  	}
   513  }
   514  
   515  func stringCast(v *Value) (string, error) {
   516  	switch x := v.value.(type) {
   517  	case float64:
   518  		return fmt.Sprintf("%v", x), nil
   519  	case int64:
   520  		return fmt.Sprintf("%v", x), nil
   521  	case string:
   522  		return x, nil
   523  	case []byte:
   524  		return string(x), nil
   525  	case bool:
   526  		return fmt.Sprintf("%v", x), nil
   527  	case nil:
   528  		// FIXME: verify this case is correct
   529  		return "NULL", nil
   530  	}
   531  	// This does not happen
   532  	return "", errCastFailure(fmt.Sprintf("cannot cast %v to string type", v.GetTypeString()))
   533  }
   534  
   535  func timestampCast(v *Value) (t time.Time, _ error) {
   536  	switch x := v.value.(type) {
   537  	case string:
   538  		return parseSQLTimestamp(x)
   539  	case []byte:
   540  		return parseSQLTimestamp(string(x))
   541  	case time.Time:
   542  		return x, nil
   543  	default:
   544  		return t, errCastFailure(fmt.Sprintf("cannot cast %v to Timestamp type", v.GetTypeString()))
   545  	}
   546  }
   547  
   548  func boolCast(v *Value) (b bool, _ error) {
   549  	sToB := func(s string) (bool, error) {
   550  		switch s {
   551  		case "true":
   552  			return true, nil
   553  		case "false":
   554  			return false, nil
   555  		default:
   556  			return false, errCastFailure("cannot cast to Bool")
   557  		}
   558  	}
   559  	switch x := v.value.(type) {
   560  	case bool:
   561  		return x, nil
   562  	case string:
   563  		return sToB(strings.ToLower(x))
   564  	case []byte:
   565  		return sToB(strings.ToLower(string(x)))
   566  	default:
   567  		return false, errCastFailure("cannot cast %v to Bool")
   568  	}
   569  }