github.com/minio/minio@v0.0.0-20240328213742-3f72439b8a27/internal/s3select/sql/funceval.go (about) 1 // Copyright (c) 2015-2021 MinIO, Inc. 2 // 3 // This file is part of MinIO Object Storage stack 4 // 5 // This program is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU Affero General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // This program is distributed in the hope that it will be useful 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU Affero General Public License for more details. 14 // 15 // You should have received a copy of the GNU Affero General Public License 16 // along with this program. If not, see <http://www.gnu.org/licenses/>. 17 18 package sql 19 20 import ( 21 "errors" 22 "fmt" 23 "strconv" 24 "strings" 25 "time" 26 ) 27 28 // FuncName - SQL function name. 29 type FuncName string 30 31 // SQL Function name constants 32 const ( 33 // Conditionals 34 sqlFnCoalesce FuncName = "COALESCE" 35 sqlFnNullIf FuncName = "NULLIF" 36 37 // Conversion 38 sqlFnCast FuncName = "CAST" 39 40 // Date and time 41 sqlFnDateAdd FuncName = "DATE_ADD" 42 sqlFnDateDiff FuncName = "DATE_DIFF" 43 sqlFnExtract FuncName = "EXTRACT" 44 sqlFnToString FuncName = "TO_STRING" 45 sqlFnToTimestamp FuncName = "TO_TIMESTAMP" 46 sqlFnUTCNow FuncName = "UTCNOW" 47 48 // String 49 sqlFnCharLength FuncName = "CHAR_LENGTH" 50 sqlFnCharacterLength FuncName = "CHARACTER_LENGTH" 51 sqlFnLower FuncName = "LOWER" 52 sqlFnSubstring FuncName = "SUBSTRING" 53 sqlFnTrim FuncName = "TRIM" 54 sqlFnUpper FuncName = "UPPER" 55 ) 56 57 var ( 58 errUnimplementedCast = errors.New("This cast not yet implemented") 59 errNonStringTrimArg = errors.New("TRIM() received a non-string argument") 60 errNonTimestampArg = errors.New("Expected a timestamp argument") 61 ) 62 63 func (e *FuncExpr) getFunctionName() FuncName { 64 switch { 65 case e.SFunc != nil: 66 return FuncName(strings.ToUpper(e.SFunc.FunctionName)) 67 case e.Count != nil: 68 return aggFnCount 69 case e.Cast != nil: 70 return sqlFnCast 71 case e.Substring != nil: 72 return sqlFnSubstring 73 case e.Extract != nil: 74 return sqlFnExtract 75 case e.Trim != nil: 76 return sqlFnTrim 77 case e.DateAdd != nil: 78 return sqlFnDateAdd 79 case e.DateDiff != nil: 80 return sqlFnDateDiff 81 default: 82 return "" 83 } 84 } 85 86 // evalSQLFnNode assumes that the FuncExpr is not an aggregation 87 // function. 88 func (e *FuncExpr) evalSQLFnNode(r Record, tableAlias string) (res *Value, err error) { 89 // Handle functions that have phrase arguments 90 switch e.getFunctionName() { 91 case sqlFnCast: 92 expr := e.Cast.Expr 93 res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType), tableAlias) 94 return 95 96 case sqlFnSubstring: 97 return handleSQLSubstring(r, e.Substring, tableAlias) 98 99 case sqlFnExtract: 100 return handleSQLExtract(r, e.Extract, tableAlias) 101 102 case sqlFnTrim: 103 return handleSQLTrim(r, e.Trim, tableAlias) 104 105 case sqlFnDateAdd: 106 return handleDateAdd(r, e.DateAdd, tableAlias) 107 108 case sqlFnDateDiff: 109 return handleDateDiff(r, e.DateDiff, tableAlias) 110 111 } 112 113 // For all simple argument functions, we evaluate the arguments here 114 argVals := make([]*Value, len(e.SFunc.ArgsList)) 115 for i, arg := range e.SFunc.ArgsList { 116 argVals[i], err = arg.evalNode(r, tableAlias) 117 if err != nil { 118 return nil, err 119 } 120 } 121 122 switch e.getFunctionName() { 123 case sqlFnCoalesce: 124 return coalesce(argVals) 125 126 case sqlFnNullIf: 127 return nullif(argVals[0], argVals[1]) 128 129 case sqlFnCharLength, sqlFnCharacterLength: 130 return charlen(argVals[0]) 131 132 case sqlFnLower: 133 return lowerCase(argVals[0]) 134 135 case sqlFnUpper: 136 return upperCase(argVals[0]) 137 138 case sqlFnUTCNow: 139 return handleUTCNow() 140 141 case sqlFnToString, sqlFnToTimestamp: 142 // TODO: implement 143 fallthrough 144 145 default: 146 return nil, errNotImplemented 147 } 148 } 149 150 func coalesce(args []*Value) (res *Value, err error) { 151 for _, arg := range args { 152 if arg.IsNull() { 153 continue 154 } 155 return arg, nil 156 } 157 return FromNull(), nil 158 } 159 160 func nullif(v1, v2 *Value) (res *Value, err error) { 161 // Handle Null cases 162 if v1.IsNull() || v2.IsNull() { 163 return v1, nil 164 } 165 166 err = inferTypesForCmp(v1, v2) 167 if err != nil { 168 return nil, err 169 } 170 171 atleastOneNumeric := v1.isNumeric() || v2.isNumeric() 172 bothNumeric := v1.isNumeric() && v2.isNumeric() 173 if atleastOneNumeric || !bothNumeric { 174 return v1, nil 175 } 176 177 if v1.SameTypeAs(*v2) { 178 return v1, nil 179 } 180 181 cmpResult, cmpErr := v1.compareOp(opEq, v2) 182 if cmpErr != nil { 183 return nil, cmpErr 184 } 185 186 if cmpResult { 187 return FromNull(), nil 188 } 189 190 return v1, nil 191 } 192 193 func charlen(v *Value) (*Value, error) { 194 inferTypeAsString(v) 195 s, ok := v.ToString() 196 if !ok { 197 err := fmt.Errorf("%s/%s expects a string argument", sqlFnCharLength, sqlFnCharacterLength) 198 return nil, errIncorrectSQLFunctionArgumentType(err) 199 } 200 return FromInt(int64(len([]rune(s)))), nil 201 } 202 203 func lowerCase(v *Value) (*Value, error) { 204 inferTypeAsString(v) 205 s, ok := v.ToString() 206 if !ok { 207 err := fmt.Errorf("%s expects a string argument", sqlFnLower) 208 return nil, errIncorrectSQLFunctionArgumentType(err) 209 } 210 return FromString(strings.ToLower(s)), nil 211 } 212 213 func upperCase(v *Value) (*Value, error) { 214 inferTypeAsString(v) 215 s, ok := v.ToString() 216 if !ok { 217 err := fmt.Errorf("%s expects a string argument", sqlFnUpper) 218 return nil, errIncorrectSQLFunctionArgumentType(err) 219 } 220 return FromString(strings.ToUpper(s)), nil 221 } 222 223 func handleDateAdd(r Record, d *DateAddFunc, tableAlias string) (*Value, error) { 224 q, err := d.Quantity.evalNode(r, tableAlias) 225 if err != nil { 226 return nil, err 227 } 228 inferTypeForArithOp(q) 229 qty, ok := q.ToFloat() 230 if !ok { 231 return nil, fmt.Errorf("QUANTITY must be a numeric argument to %s()", sqlFnDateAdd) 232 } 233 234 ts, err := d.Timestamp.evalNode(r, tableAlias) 235 if err != nil { 236 return nil, err 237 } 238 if err = inferTypeAsTimestamp(ts); err != nil { 239 return nil, err 240 } 241 t, ok := ts.ToTimestamp() 242 if !ok { 243 return nil, fmt.Errorf("%s() expects a timestamp argument", sqlFnDateAdd) 244 } 245 246 return dateAdd(strings.ToUpper(d.DatePart), qty, t) 247 } 248 249 func handleDateDiff(r Record, d *DateDiffFunc, tableAlias string) (*Value, error) { 250 tval1, err := d.Timestamp1.evalNode(r, tableAlias) 251 if err != nil { 252 return nil, err 253 } 254 if err = inferTypeAsTimestamp(tval1); err != nil { 255 return nil, err 256 } 257 ts1, ok := tval1.ToTimestamp() 258 if !ok { 259 return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff) 260 } 261 262 tval2, err := d.Timestamp2.evalNode(r, tableAlias) 263 if err != nil { 264 return nil, err 265 } 266 if err = inferTypeAsTimestamp(tval2); err != nil { 267 return nil, err 268 } 269 ts2, ok := tval2.ToTimestamp() 270 if !ok { 271 return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff) 272 } 273 274 return dateDiff(strings.ToUpper(d.DatePart), ts1, ts2) 275 } 276 277 func handleUTCNow() (*Value, error) { 278 return FromTimestamp(time.Now().UTC()), nil 279 } 280 281 func handleSQLSubstring(r Record, e *SubstringFunc, tableAlias string) (val *Value, err error) { 282 // Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and 283 // SUBSTRING('abc', 2, 1) are supported. 284 285 // Evaluate the string argument 286 v1, err := e.Expr.evalNode(r, tableAlias) 287 if err != nil { 288 return nil, err 289 } 290 inferTypeAsString(v1) 291 s, ok := v1.ToString() 292 if !ok { 293 err := fmt.Errorf("Incorrect argument type passed to %s", sqlFnSubstring) 294 return nil, errIncorrectSQLFunctionArgumentType(err) 295 } 296 297 // Assemble other arguments 298 arg2, arg3 := e.From, e.For 299 // Check if the second form of substring is being used 300 if e.From == nil { 301 arg2, arg3 = e.Arg2, e.Arg3 302 } 303 304 // Evaluate the FROM argument 305 v2, err := arg2.evalNode(r, tableAlias) 306 if err != nil { 307 return nil, err 308 } 309 inferTypeForArithOp(v2) 310 startIdx, ok := v2.ToInt() 311 if !ok { 312 err := fmt.Errorf("Incorrect type for start index argument in %s", sqlFnSubstring) 313 return nil, errIncorrectSQLFunctionArgumentType(err) 314 } 315 316 length := -1 317 // Evaluate the optional FOR argument 318 if arg3 != nil { 319 v3, err := arg3.evalNode(r, tableAlias) 320 if err != nil { 321 return nil, err 322 } 323 inferTypeForArithOp(v3) 324 lenInt, ok := v3.ToInt() 325 if !ok { 326 err := fmt.Errorf("Incorrect type for length argument in %s", sqlFnSubstring) 327 return nil, errIncorrectSQLFunctionArgumentType(err) 328 } 329 length = int(lenInt) 330 if length < 0 { 331 err := fmt.Errorf("Negative length argument in %s", sqlFnSubstring) 332 return nil, errIncorrectSQLFunctionArgumentType(err) 333 } 334 } 335 336 res, err := evalSQLSubstring(s, int(startIdx), length) 337 return FromString(res), err 338 } 339 340 func handleSQLTrim(r Record, e *TrimFunc, tableAlias string) (res *Value, err error) { 341 chars := "" 342 ok := false 343 if e.TrimChars != nil { 344 charsV, cerr := e.TrimChars.evalNode(r, tableAlias) 345 if cerr != nil { 346 return nil, cerr 347 } 348 inferTypeAsString(charsV) 349 chars, ok = charsV.ToString() 350 if !ok { 351 return nil, errNonStringTrimArg 352 } 353 } 354 355 fromV, ferr := e.TrimFrom.evalNode(r, tableAlias) 356 if ferr != nil { 357 return nil, ferr 358 } 359 inferTypeAsString(fromV) 360 from, ok := fromV.ToString() 361 if !ok { 362 return nil, errNonStringTrimArg 363 } 364 365 result, terr := evalSQLTrim(e.TrimWhere, chars, from) 366 if terr != nil { 367 return nil, terr 368 } 369 return FromString(result), nil 370 } 371 372 func handleSQLExtract(r Record, e *ExtractFunc, tableAlias string) (res *Value, err error) { 373 timeVal, verr := e.From.evalNode(r, tableAlias) 374 if verr != nil { 375 return nil, verr 376 } 377 378 if err = inferTypeAsTimestamp(timeVal); err != nil { 379 return nil, err 380 } 381 382 t, ok := timeVal.ToTimestamp() 383 if !ok { 384 return nil, errNonTimestampArg 385 } 386 387 return extract(strings.ToUpper(e.Timeword), t) 388 } 389 390 func errUnsupportedCast(fromType, toType string) error { 391 return fmt.Errorf("Cannot cast from %v to %v", fromType, toType) 392 } 393 394 func errCastFailure(msg string) error { 395 return fmt.Errorf("Error casting: %s", msg) 396 } 397 398 // Allowed cast types 399 const ( 400 castBool = "BOOL" 401 castInt = "INT" 402 castInteger = "INTEGER" 403 castString = "STRING" 404 castFloat = "FLOAT" 405 castDecimal = "DECIMAL" 406 castNumeric = "NUMERIC" 407 castTimestamp = "TIMESTAMP" 408 ) 409 410 func (e *Expression) castTo(r Record, castType string, tableAlias string) (res *Value, err error) { 411 v, err := e.evalNode(r, tableAlias) 412 if err != nil { 413 return nil, err 414 } 415 416 switch castType { 417 case castInt, castInteger: 418 i, err := intCast(v) 419 return FromInt(i), err 420 421 case castFloat: 422 f, err := floatCast(v) 423 return FromFloat(f), err 424 425 case castString: 426 s, err := stringCast(v) 427 return FromString(s), err 428 429 case castTimestamp: 430 t, err := timestampCast(v) 431 return FromTimestamp(t), err 432 433 case castBool: 434 b, err := boolCast(v) 435 return FromBool(b), err 436 437 case castDecimal, castNumeric: 438 fallthrough 439 440 default: 441 return nil, errUnimplementedCast 442 } 443 } 444 445 func intCast(v *Value) (int64, error) { 446 // This conversion truncates floating point numbers to 447 // integer. 448 strToInt := func(s string) (int64, bool) { 449 i, errI := strconv.ParseInt(s, 10, 64) 450 if errI == nil { 451 return i, true 452 } 453 f, errF := strconv.ParseFloat(s, 64) 454 if errF == nil { 455 return int64(f), true 456 } 457 return 0, false 458 } 459 460 switch x := v.value.(type) { 461 case float64: 462 // Truncate fractional part 463 return int64(x), nil 464 case int64: 465 return x, nil 466 case string: 467 // Parse as number, truncate floating point if 468 // needed. 469 // String might contain trimming spaces, which 470 // needs to be trimmed. 471 res, ok := strToInt(strings.TrimSpace(x)) 472 if !ok { 473 return 0, errCastFailure("could not parse as int") 474 } 475 return res, nil 476 case []byte: 477 // Parse as number, truncate floating point if 478 // needed. 479 // String might contain trimming spaces, which 480 // needs to be trimmed. 481 res, ok := strToInt(strings.TrimSpace(string(x))) 482 if !ok { 483 return 0, errCastFailure("could not parse as int") 484 } 485 return res, nil 486 487 default: 488 return 0, errUnsupportedCast(v.GetTypeString(), castInt) 489 } 490 } 491 492 func floatCast(v *Value) (float64, error) { 493 switch x := v.value.(type) { 494 case float64: 495 return x, nil 496 case int64: 497 return float64(x), nil 498 case string: 499 f, err := strconv.ParseFloat(strings.TrimSpace(x), 64) 500 if err != nil { 501 return 0, errCastFailure("could not parse as float") 502 } 503 return f, nil 504 case []byte: 505 f, err := strconv.ParseFloat(strings.TrimSpace(string(x)), 64) 506 if err != nil { 507 return 0, errCastFailure("could not parse as float") 508 } 509 return f, nil 510 default: 511 return 0, errUnsupportedCast(v.GetTypeString(), castFloat) 512 } 513 } 514 515 func stringCast(v *Value) (string, error) { 516 switch x := v.value.(type) { 517 case float64: 518 return fmt.Sprintf("%v", x), nil 519 case int64: 520 return fmt.Sprintf("%v", x), nil 521 case string: 522 return x, nil 523 case []byte: 524 return string(x), nil 525 case bool: 526 return fmt.Sprintf("%v", x), nil 527 case nil: 528 // FIXME: verify this case is correct 529 return "NULL", nil 530 } 531 // This does not happen 532 return "", errCastFailure(fmt.Sprintf("cannot cast %v to string type", v.GetTypeString())) 533 } 534 535 func timestampCast(v *Value) (t time.Time, _ error) { 536 switch x := v.value.(type) { 537 case string: 538 return parseSQLTimestamp(x) 539 case []byte: 540 return parseSQLTimestamp(string(x)) 541 case time.Time: 542 return x, nil 543 default: 544 return t, errCastFailure(fmt.Sprintf("cannot cast %v to Timestamp type", v.GetTypeString())) 545 } 546 } 547 548 func boolCast(v *Value) (b bool, _ error) { 549 sToB := func(s string) (bool, error) { 550 switch s { 551 case "true": 552 return true, nil 553 case "false": 554 return false, nil 555 default: 556 return false, errCastFailure("cannot cast to Bool") 557 } 558 } 559 switch x := v.value.(type) { 560 case bool: 561 return x, nil 562 case string: 563 return sToB(strings.ToLower(x)) 564 case []byte: 565 return sToB(strings.ToLower(string(x))) 566 default: 567 return false, errCastFailure("cannot cast %v to Bool") 568 } 569 }