github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/scalar.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package planbuilder 16 17 import ( 18 "encoding/hex" 19 "fmt" 20 "strconv" 21 "strings" 22 23 ast "github.com/dolthub/vitess/go/vt/sqlparser" 24 25 "github.com/dolthub/go-mysql-server/sql" 26 "github.com/dolthub/go-mysql-server/sql/encodings" 27 "github.com/dolthub/go-mysql-server/sql/expression" 28 "github.com/dolthub/go-mysql-server/sql/expression/function" 29 "github.com/dolthub/go-mysql-server/sql/expression/function/json" 30 "github.com/dolthub/go-mysql-server/sql/fulltext" 31 "github.com/dolthub/go-mysql-server/sql/plan" 32 "github.com/dolthub/go-mysql-server/sql/types" 33 ) 34 35 func (b *Builder) buildWhere(inScope *scope, where *ast.Where) { 36 if where == nil { 37 return 38 } 39 filter := b.buildScalar(inScope, where.Expr) 40 filterNode := plan.NewFilter(filter, inScope.node) 41 inScope.node = filterNode 42 } 43 44 func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { 45 defer func() { 46 if !(b.bindCtx == nil || b.bindCtx.resolveOnly) { 47 return 48 } 49 50 if be, ok := ex.(expression.BinaryExpression); ok { 51 left := be.Left() 52 right := be.Right() 53 if leftBindVar, ok := left.(*expression.BindVar); ok { 54 if typ, ok := hasColumnType(right); ok { 55 leftBindVar.Typ = typ 56 left = leftBindVar 57 } 58 } else if rightBindVar, ok := right.(*expression.BindVar); ok { 59 if typ, ok := hasColumnType(left); ok { 60 rightBindVar.Typ = typ 61 right = rightBindVar 62 } 63 } 64 ex, _ = be.WithChildren(left, right) 65 } 66 }() 67 68 switch v := e.(type) { 69 case *ast.Default: 70 return expression.WrapExpression(expression.NewDefaultColumn(v.ColName)) 71 case *ast.SubstrExpr: 72 var name sql.Expression 73 if v.Name != nil { 74 name = b.buildScalar(inScope, v.Name) 75 } else { 76 name = b.buildScalar(inScope, v.StrVal) 77 } 78 start := b.buildScalar(inScope, v.From) 79 80 if v.To == nil { 81 return &function.Substring{Str: name, Start: start} 82 } 83 len := b.buildScalar(inScope, v.To) 84 return &function.Substring{Str: name, Start: start, Len: len} 85 case *ast.TrimExpr: 86 pat := b.buildScalar(inScope, v.Pattern) 87 str := b.buildScalar(inScope, v.Str) 88 return function.NewTrim(str, pat, v.Dir) 89 case *ast.ComparisonExpr: 90 return b.buildComparison(inScope, v) 91 case *ast.IsExpr: 92 return b.buildIsExprToExpression(inScope, v) 93 case *ast.NotExpr: 94 c := b.buildScalar(inScope, v.Expr) 95 return expression.NewNot(c) 96 case *ast.SQLVal: 97 return b.ConvertVal(v) 98 case ast.BoolVal: 99 return expression.NewLiteral(bool(v), types.Boolean) 100 case *ast.NullVal: 101 return expression.NewLiteral(nil, types.Null) 102 case *ast.ColName: 103 dbName := strings.ToLower(v.Qualifier.Qualifier.String()) 104 tblName := strings.ToLower(v.Qualifier.Name.String()) 105 colName := strings.ToLower(v.Name.String()) 106 c, ok := inScope.resolveColumn(dbName, tblName, colName, true, false) 107 if !ok { 108 sysVar, scope, ok := b.buildSysVar(v, ast.SetScope_None) 109 if ok { 110 return sysVar 111 } 112 var err error 113 if scope == ast.SetScope_User { 114 err = sql.ErrUnknownUserVariable.New(colName) 115 } else if scope == ast.SetScope_Persist || scope == ast.SetScope_PersistOnly { 116 err = sql.ErrUnknownUserVariable.New(colName) 117 } else if scope == ast.SetScope_Global || scope == ast.SetScope_Session { 118 err = sql.ErrUnknownSystemVariable.New(colName) 119 } else if tblName != "" && !inScope.hasTable(tblName) { 120 err = sql.ErrTableNotFound.New(tblName) 121 } else if tblName != "" { 122 err = sql.ErrTableColumnNotFound.New(tblName, colName) 123 } else { 124 err = sql.ErrColumnNotFound.New(v) 125 } 126 b.handleErr(err) 127 } 128 c = c.withOriginal(v.Name.String()) 129 return c.scalarGf() 130 case *ast.FuncExpr: 131 name := v.Name.Lowered() 132 133 if isAggregateFunc(name) && v.Over == nil { 134 // TODO this assumes aggregate is in the same scope 135 // also need to avoid nested aggregates 136 return b.buildAggregateFunc(inScope, name, v) 137 } else if isWindowFunc(name) { 138 return b.buildWindowFunc(inScope, name, v, (*ast.WindowDef)(v.Over)) 139 } 140 141 f, err := b.cat.Function(b.ctx, name) 142 if err != nil { 143 b.handleErr(err) 144 } 145 146 args := make([]sql.Expression, len(v.Exprs)) 147 for i, e := range v.Exprs { 148 args[i] = b.selectExprToExpression(inScope, e) 149 } 150 151 if name == "json_value" { 152 if len(args) == 3 { 153 args[2] = b.getJsonValueTypeLiteral(args[2]) 154 } 155 } 156 157 rf, err := f.NewInstance(args) 158 if err != nil { 159 b.handleErr(err) 160 } 161 162 // NOTE: Not all aggregate functions support DISTINCT. Fortunately, the vitess parser will throw 163 // errors for when DISTINCT is used on aggregate functions that don't support DISTINCT. 164 if v.Distinct { 165 if len(args) != 1 { 166 return nil 167 } 168 args[0] = expression.NewDistinctExpression(args[0]) 169 } 170 171 if _, ok := rf.(sql.NonDeterministicExpression); ok && inScope.nearestSubquery() != nil { 172 inScope.nearestSubquery().markVolatile() 173 } 174 175 return rf 176 177 case *ast.GroupConcatExpr: 178 // TODO this is an aggregation 179 return b.buildGroupConcat(inScope, v) 180 case *ast.ParenExpr: 181 return b.buildScalar(inScope, v.Expr) 182 case *ast.AndExpr: 183 lhs := b.buildScalar(inScope, v.Left) 184 rhs := b.buildScalar(inScope, v.Right) 185 return expression.NewAnd(lhs, rhs) 186 case *ast.OrExpr: 187 lhs := b.buildScalar(inScope, v.Left) 188 rhs := b.buildScalar(inScope, v.Right) 189 return expression.NewOr(lhs, rhs) 190 case *ast.XorExpr: 191 lhs := b.buildScalar(inScope, v.Left) 192 rhs := b.buildScalar(inScope, v.Right) 193 return expression.NewXor(lhs, rhs) 194 case *ast.ConvertUsingExpr: 195 expr := b.buildScalar(inScope, v.Expr) 196 charset, err := sql.ParseCharacterSet(v.Type) 197 if err != nil { 198 b.handleErr(err) 199 } 200 return expression.NewConvertUsing(expr, charset) 201 case *ast.CharExpr: 202 args := make([]sql.Expression, len(v.Exprs)) 203 for i, e := range v.Exprs { 204 args[i] = b.selectExprToExpression(inScope, e) 205 } 206 207 f, err := function.NewChar(args...) 208 if err != nil { 209 b.handleErr(err) 210 } 211 212 collId, err := sql.ParseCollation(&v.Type, nil, true) 213 if err != nil { 214 b.handleErr(err) 215 } 216 217 charFunc := f.(*function.Char) 218 charFunc.Collation = collId 219 return charFunc 220 case *ast.ConvertExpr: 221 var err error 222 typeLength := 0 223 if v.Type.Length != nil { 224 // TODO move to vitess 225 typeLength, err = strconv.Atoi(v.Type.Length.String()) 226 if err != nil { 227 b.handleErr(err) 228 } 229 } 230 231 typeScale := 0 232 if v.Type.Scale != nil { 233 // TODO move to vitess 234 typeScale, err = strconv.Atoi(v.Type.Scale.String()) 235 if err != nil { 236 b.handleErr(err) 237 } 238 } 239 expr := b.buildScalar(inScope, v.Expr) 240 ret, err := b.f.buildConvert(expr, v.Type.Type, typeLength, typeScale) 241 if err != nil { 242 b.handleErr(err) 243 } 244 return ret 245 case ast.InjectedExpr: 246 resolvedChildren := make([]any, len(v.Children)) 247 for i, child := range v.Children { 248 resolvedChildren[i] = b.buildScalar(inScope, child) 249 } 250 expr, err := v.Expression.WithResolvedChildren(resolvedChildren) 251 if err != nil { 252 b.handleErr(err) 253 } 254 if sqlExpr, ok := expr.(sql.Expression); ok { 255 return sqlExpr 256 } 257 b.handleErr(fmt.Errorf("Injected expression does not resolve to a valid expression")) 258 return nil 259 case *ast.RangeCond: 260 val := b.buildScalar(inScope, v.Left) 261 lower := b.buildScalar(inScope, v.From) 262 upper := b.buildScalar(inScope, v.To) 263 264 switch strings.ToLower(v.Operator) { 265 case ast.BetweenStr: 266 return expression.NewBetween(val, lower, upper) 267 case ast.NotBetweenStr: 268 return expression.NewNot(expression.NewBetween(val, lower, upper)) 269 default: 270 return nil 271 } 272 case ast.ValTuple: 273 var exprs = make([]sql.Expression, len(v)) 274 for i, e := range v { 275 expr := b.buildScalar(inScope, e) 276 exprs[i] = expr 277 } 278 return expression.NewTuple(exprs...) 279 280 case *ast.BinaryExpr: 281 return b.buildBinaryScalar(inScope, v) 282 case *ast.UnaryExpr: 283 return b.buildUnaryScalar(inScope, v) 284 case *ast.Subquery: 285 sqScope := inScope.pushSubquery() 286 selectString := ast.String(v.Select) 287 selScope := b.buildSelectStmt(sqScope, v.Select) 288 // TODO: get the original select statement, not the reconstruction 289 sq := plan.NewSubquery(selScope.node, selectString) 290 sq = sq.WithCorrelated(sqScope.correlated()) 291 if b.TriggerCtx().Active { 292 sq = sq.WithVolatile() 293 } 294 return sq 295 case *ast.CaseExpr: 296 return b.buildCaseExpr(inScope, v) 297 case *ast.IntervalExpr: 298 e := b.buildScalar(inScope, v.Expr) 299 return expression.NewInterval(e, v.Unit) 300 case *ast.CollateExpr: 301 // handleCollateExpr is meant to handle generic text-returning expressions that should be reinterpreted as a different collation. 302 innerExpr := b.buildScalar(inScope, v.Expr) 303 //TODO: rename this from Charset to Collation 304 collation, err := sql.ParseCollation(nil, &v.Charset, false) 305 if err != nil { 306 b.handleErr(err) 307 } 308 // If we're collating a string literal, we check that the charset and collation match now. Other string sources 309 // (such as from tables) will have their own charset, which we won't know until after the parsing stage. 310 charSet := b.ctx.GetCharacterSet() 311 if _, isLiteral := innerExpr.(*expression.Literal); isLiteral && collation.CharacterSet() != charSet { 312 b.handleErr(sql.ErrCollationInvalidForCharSet.New(collation.Name(), charSet.Name())) 313 } 314 return expression.NewCollatedExpression(innerExpr, collation) 315 case *ast.ValuesFuncExpr: 316 if b.insertActive { 317 if v.Name.Qualifier.Name.String() == "" { 318 v.Name.Qualifier.Name = ast.NewTableIdent(OnDupValuesPrefix) 319 } 320 dbName := strings.ToLower(v.Name.Qualifier.Qualifier.String()) 321 tblName := strings.ToLower(v.Name.Qualifier.Name.String()) 322 colName := strings.ToLower(v.Name.Name.String()) 323 col, ok := inScope.resolveColumn(dbName, tblName, colName, false, false) 324 if !ok { 325 err := fmt.Errorf("expected ON DUPLICATE KEY ... VALUES() to reference a column, found: %s", v.Name.String()) 326 b.handleErr(err) 327 } 328 return col.scalarGf() 329 } else { 330 col := b.buildScalar(inScope, v.Name) 331 fn, err := b.cat.Function(b.ctx, "values") 332 if err != nil { 333 b.handleErr(err) 334 } 335 values, err := fn.NewInstance([]sql.Expression{col}) 336 if err != nil { 337 b.handleErr(err) 338 } 339 return values 340 } 341 case *ast.ExistsExpr: 342 sqScope := inScope.push() 343 sqScope.initSubquery() 344 selScope := b.buildSelectStmt(sqScope, v.Subquery.Select) 345 selectString := ast.String(v.Subquery.Select) 346 sq := plan.NewSubquery(selScope.node, selectString) 347 sq = sq.WithCorrelated(sqScope.correlated()) 348 return plan.NewExistsSubquery(sq) 349 case *ast.TimestampFuncExpr: 350 var ( 351 unit sql.Expression 352 expr1 sql.Expression 353 expr2 sql.Expression 354 ) 355 356 unit = expression.NewLiteral(v.Unit, types.LongText) 357 expr1 = b.buildScalar(inScope, v.Expr1) 358 expr2 = b.buildScalar(inScope, v.Expr2) 359 360 if v.Name == "timestampdiff" { 361 return function.NewTimestampDiff(unit, expr1, expr2) 362 } else if v.Name == "timestampadd" { 363 return nil 364 } 365 return nil 366 case *ast.ExtractFuncExpr: 367 var unit sql.Expression = expression.NewLiteral(strings.ToUpper(v.Unit), types.LongText) 368 expr := b.buildScalar(inScope, v.Expr) 369 return function.NewExtract(unit, expr) 370 case *ast.MatchExpr: 371 return b.buildMatchAgainst(inScope, v) 372 default: 373 b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(e))) 374 } 375 return nil 376 } 377 378 // getJsonValueTypeLiteral converts a type coercion string into a literal 379 // expression with the zero type of the coercion (see json_value function). 380 func (b *Builder) getJsonValueTypeLiteral(e sql.Expression) sql.Expression { 381 typLit, ok := e.(*expression.Literal) 382 if !ok { 383 err := fmt.Errorf("invalid json_value coercion type: %s", e) 384 b.handleErr(err) 385 } 386 convStr, _, err := types.LongText.Convert(typLit.Value()) 387 if err != nil { 388 err := fmt.Errorf("invalid json_value coercion type: %s; %s", typLit.Value(), err.Error()) 389 b.handleErr(err) 390 } 391 var typ sql.Type 392 switch strings.ToLower(convStr.(string)) { 393 case "float": 394 typ = types.Float32 395 case "double", "decimal": 396 typ = types.Float64 397 case "signed": 398 typ = types.Int64 399 case "unsigned": 400 typ = types.Uint64 401 case "char": 402 typ = types.Text 403 case "json": 404 typ = types.JSON 405 case "time": 406 typ = types.Time 407 case "datetime": 408 typ = types.Datetime 409 case "date": 410 typ = types.Date 411 case "year": 412 typ = types.Year 413 default: 414 err := fmt.Errorf("invalid type for json_value: %s", convStr) 415 b.handleErr(err) 416 } 417 return expression.NewLiteral(typ.Zero(), typ) 418 } 419 420 func (b *Builder) buildUnaryScalar(inScope *scope, e *ast.UnaryExpr) sql.Expression { 421 switch strings.ToLower(e.Operator) { 422 case ast.MinusStr: 423 expr := b.buildScalar(inScope, e.Expr) 424 return expression.NewUnaryMinus(expr) 425 case ast.PlusStr: 426 // Unary plus expressions do nothing (do not turn the expression positive). Just return the underlying expressio return b.buildScalar(inScope, e.Expr) 427 return b.buildScalar(inScope, e.Expr) 428 case ast.BangStr: 429 c := b.buildScalar(inScope, e.Expr) 430 return expression.NewNot(c) 431 case ast.BinaryStr: 432 c := b.buildScalar(inScope, e.Expr) 433 return expression.NewBinary(c) 434 default: 435 lowerOperator := strings.TrimSpace(strings.ToLower(e.Operator)) 436 if strings.HasPrefix(lowerOperator, "_") { 437 // This is a character set introducer, so we need to decode the string to our internal encoding (`utf8mb4`) 438 charSet, err := sql.ParseCharacterSet(lowerOperator[1:]) 439 if err != nil { 440 b.handleErr(err) 441 } 442 if charSet.Encoder() == nil { 443 err := sql.ErrUnsupportedFeature.New("unsupported character set: " + charSet.Name()) 444 b.handleErr(err) 445 } 446 447 // Due to how vitess orders expressions, COLLATE is a child rather than a parent, so we need to handle it in a special way 448 collation := charSet.DefaultCollation() 449 if collateExpr, ok := e.Expr.(*ast.CollateExpr); ok { 450 // We extract the expression out of CollateExpr as we're only concerned about the collation string 451 e.Expr = collateExpr.Expr 452 // TODO: rename this from Charset to Collation 453 collation, err = sql.ParseCollation(nil, &collateExpr.Charset, false) 454 if err != nil { 455 b.handleErr(err) 456 } 457 if collation.CharacterSet() != charSet { 458 err := sql.ErrCollationInvalidForCharSet.New(collation.Name(), charSet.Name()) 459 b.handleErr(err) 460 } 461 } 462 463 // Character set introducers only work on string literals 464 expr := b.buildScalar(inScope, e.Expr) 465 if _, ok := expr.(*expression.Literal); !ok || !types.IsText(expr.Type()) { 466 err := sql.ErrCharSetIntroducer.New() 467 b.handleErr(err) 468 } 469 literal, _ := expr.Eval(b.ctx, nil) 470 471 // Internally all strings are `utf8mb4`, so we need to decode the string (which applies the introducer) 472 if strLiteral, ok := literal.(string); ok { 473 decodedLiteral, ok := charSet.Encoder().Decode(encodings.StringToBytes(strLiteral)) 474 if !ok { 475 err := sql.ErrCharSetInvalidString.New(charSet.Name(), strLiteral) 476 b.handleErr(err) 477 } 478 return expression.NewLiteral(encodings.BytesToString(decodedLiteral), types.CreateLongText(collation)) 479 } else if byteLiteral, ok := literal.([]byte); ok { 480 decodedLiteral, ok := charSet.Encoder().Decode(byteLiteral) 481 if !ok { 482 err := sql.ErrCharSetInvalidString.New(charSet.Name(), strLiteral) 483 b.handleErr(err) 484 } 485 return expression.NewLiteral(decodedLiteral, types.CreateLongText(collation)) 486 } else { 487 // Should not be possible 488 err := fmt.Errorf("expression literal returned type `%s` but literal value had type `%T`", 489 expr.Type().String(), literal) 490 b.handleErr(err) 491 } 492 } 493 err := sql.ErrUnsupportedFeature.New("unary operator: " + e.Operator) 494 b.handleErr(err) 495 } 496 return nil 497 } 498 499 func (b *Builder) buildBinaryScalar(inScope *scope, be *ast.BinaryExpr) sql.Expression { 500 expr, err := b.binaryExprToExpression(inScope, be) 501 if err != nil { 502 b.handleErr(err) 503 } 504 return expr 505 } 506 507 func (b *Builder) buildComparison(inScope *scope, c *ast.ComparisonExpr) sql.Expression { 508 left := b.buildScalar(inScope, c.Left) 509 right := b.buildScalar(inScope, c.Right) 510 511 var escape sql.Expression = nil 512 if c.Escape != nil { 513 escape = b.buildScalar(inScope, c.Escape) 514 } 515 516 switch strings.ToLower(c.Operator) { 517 case ast.RegexpStr: 518 return expression.NewRegexp(left, right) 519 case ast.NotRegexpStr: 520 return expression.NewNot(expression.NewRegexp(left, right)) 521 case ast.EqualStr: 522 return expression.NewEquals(left, right) 523 case ast.LessThanStr: 524 return expression.NewLessThan(left, right) 525 case ast.LessEqualStr: 526 return expression.NewLessThanOrEqual(left, right) 527 case ast.GreaterThanStr: 528 return expression.NewGreaterThan(left, right) 529 case ast.GreaterEqualStr: 530 return expression.NewGreaterThanOrEqual(left, right) 531 case ast.NullSafeEqualStr: 532 return expression.NewNullSafeEquals(left, right) 533 case ast.NotEqualStr: 534 return expression.NewNot( 535 expression.NewEquals(left, right), 536 ) 537 case ast.InStr: 538 switch right.(type) { 539 case expression.Tuple: 540 return expression.NewInTuple(left, right) 541 case *plan.Subquery: 542 return plan.NewInSubquery(left, right) 543 default: 544 err := sql.ErrUnsupportedFeature.New(fmt.Sprintf("IN %T", right)) 545 b.handleErr(err) 546 } 547 case ast.NotInStr: 548 switch right.(type) { 549 case expression.Tuple: 550 return expression.NewNotInTuple(left, right) 551 case *plan.Subquery: 552 return plan.NewNotInSubquery(left, right) 553 default: 554 err := sql.ErrUnsupportedFeature.New(fmt.Sprintf("NOT IN %T", right)) 555 b.handleErr(err) 556 } 557 case ast.LikeStr: 558 return expression.NewLike(left, right, escape) 559 case ast.NotLikeStr: 560 return expression.NewNot(expression.NewLike(left, right, escape)) 561 default: 562 err := sql.ErrUnsupportedFeature.New(c.Operator) 563 b.handleErr(err) 564 } 565 return nil 566 } 567 568 func hasColumnType(e sql.Expression) (sql.Type, bool) { 569 var typ sql.Type 570 sql.Inspect(e, func(e sql.Expression) bool { 571 if col, ok := e.(*expression.GetField); ok { 572 typ = col.Type() 573 return false 574 } 575 return true 576 }) 577 return typ, typ != nil 578 } 579 580 func (b *Builder) buildCaseExpr(inScope *scope, e *ast.CaseExpr) sql.Expression { 581 expr, err := b.caseExprToExpression(inScope, e) 582 if err != nil { 583 b.handleErr(err) 584 } 585 return expr 586 } 587 588 func (b *Builder) buildIsExprToExpression(inScope *scope, c *ast.IsExpr) sql.Expression { 589 e := b.buildScalar(inScope, c.Expr) 590 switch strings.ToLower(c.Operator) { 591 case ast.IsNullStr: 592 return expression.NewIsNull(e) 593 case ast.IsNotNullStr: 594 return expression.NewNot(expression.NewIsNull(e)) 595 case ast.IsTrueStr: 596 return expression.NewIsTrue(e) 597 case ast.IsFalseStr: 598 return expression.NewIsFalse(e) 599 case ast.IsNotTrueStr: 600 return expression.NewNot(expression.NewIsTrue(e)) 601 case ast.IsNotFalseStr: 602 return expression.NewNot(expression.NewIsFalse(e)) 603 default: 604 err := sql.ErrUnsupportedSyntax.New(ast.String(c)) 605 b.handleErr(err) 606 } 607 return nil 608 } 609 610 func (b *Builder) binaryExprToExpression(inScope *scope, be *ast.BinaryExpr) (sql.Expression, error) { 611 l := b.buildScalar(inScope, be.Left) 612 r := b.buildScalar(inScope, be.Right) 613 614 operator := strings.ToLower(be.Operator) 615 switch operator { 616 case 617 ast.PlusStr, 618 ast.MinusStr, 619 ast.MultStr, 620 ast.DivStr, 621 ast.ShiftLeftStr, 622 ast.ShiftRightStr, 623 ast.BitAndStr, 624 ast.BitOrStr, 625 ast.BitXorStr, 626 ast.IntDivStr, 627 ast.ModStr: 628 629 _, lok := l.(*expression.Interval) 630 _, rok := r.(*expression.Interval) 631 if lok && be.Operator == "-" { 632 return nil, sql.ErrUnsupportedSyntax.New("subtracting from an interval") 633 } else if (lok || rok) && be.Operator != "+" && be.Operator != "-" { 634 return nil, sql.ErrUnsupportedSyntax.New("only + and - can be used to add or subtract intervals from dates") 635 } else if lok && rok { 636 return nil, sql.ErrUnsupportedSyntax.New("intervals cannot be added or subtracted from other intervals") 637 } 638 639 switch operator { 640 case ast.DivStr: 641 return expression.NewDiv(l, r), nil 642 case ast.ModStr: 643 return expression.NewMod(l, r), nil 644 case ast.BitAndStr, ast.BitOrStr, ast.BitXorStr, ast.ShiftRightStr, ast.ShiftLeftStr: 645 return expression.NewBitOp(l, r, be.Operator), nil 646 case ast.IntDivStr: 647 return expression.NewIntDiv(l, r), nil 648 case ast.MultStr: 649 return expression.NewMult(l, r), nil 650 case ast.PlusStr: 651 return expression.NewPlus(l, r), nil 652 case ast.MinusStr: 653 return expression.NewMinus(l, r), nil 654 default: 655 return nil, sql.ErrUnsupportedSyntax.New("unsupported operator: %s", be.Operator) 656 } 657 658 case ast.JSONExtractOp, ast.JSONUnquoteExtractOp: 659 jsonExtract, err := json.NewJSONExtract(l, r) 660 if err != nil { 661 return nil, err 662 } 663 664 if operator == ast.JSONUnquoteExtractOp { 665 return json.NewJSONUnquote(jsonExtract), nil 666 } 667 return jsonExtract, nil 668 669 default: 670 return nil, sql.ErrUnsupportedFeature.New(be.Operator) 671 } 672 } 673 674 func (b *Builder) caseExprToExpression(inScope *scope, e *ast.CaseExpr) (sql.Expression, error) { 675 var expr sql.Expression 676 677 if e.Expr != nil { 678 expr = b.buildScalar(inScope, e.Expr) 679 } 680 681 var branches []expression.CaseBranch 682 for _, w := range e.Whens { 683 var cond sql.Expression 684 cond = b.buildScalar(inScope, w.Cond) 685 686 var val sql.Expression 687 val = b.buildScalar(inScope, w.Val) 688 689 branches = append(branches, expression.CaseBranch{ 690 Cond: cond, 691 Value: val, 692 }) 693 } 694 695 var elseExpr sql.Expression 696 if e.Else != nil { 697 elseExpr = b.buildScalar(inScope, e.Else) 698 } 699 700 return expression.NewCase(expr, branches, elseExpr), nil 701 } 702 703 func (b *Builder) intervalExprToExpression(inScope *scope, e *ast.IntervalExpr) *expression.Interval { 704 expr := b.buildScalar(inScope, e.Expr) 705 return expression.NewInterval(expr, e.Unit) 706 } 707 708 // Convert an integer, represented by the specified string in the specified 709 // base, to its smallest representation possible, out of: 710 // int8, uint8, int16, uint16, int32, uint32, int64 and uint64 711 func (b *Builder) convertInt(value string, base int) *expression.Literal { 712 if i8, err := strconv.ParseInt(value, base, 8); err == nil { 713 return expression.NewLiteral(int8(i8), types.Int8) 714 } 715 if ui8, err := strconv.ParseUint(value, base, 8); err == nil { 716 return expression.NewLiteral(uint8(ui8), types.Uint8) 717 } 718 if i16, err := strconv.ParseInt(value, base, 16); err == nil { 719 return expression.NewLiteral(int16(i16), types.Int16) 720 } 721 if ui16, err := strconv.ParseUint(value, base, 16); err == nil { 722 return expression.NewLiteral(uint16(ui16), types.Uint16) 723 } 724 if i32, err := strconv.ParseInt(value, base, 32); err == nil { 725 return expression.NewLiteral(int32(i32), types.Int32) 726 } 727 if ui32, err := strconv.ParseUint(value, base, 32); err == nil { 728 return expression.NewLiteral(uint32(ui32), types.Uint32) 729 } 730 if i64, err := strconv.ParseInt(value, base, 64); err == nil { 731 return expression.NewLiteral(int64(i64), types.Int64) 732 } 733 if ui64, err := strconv.ParseUint(value, base, 64); err == nil { 734 return expression.NewLiteral(uint64(ui64), types.Uint64) 735 } 736 if decimal, _, err := types.InternalDecimalType.Convert(value); err == nil { 737 return expression.NewLiteral(decimal, types.InternalDecimalType) 738 } 739 740 b.handleErr(fmt.Errorf("could not convert %s to any numerical type", value)) 741 return nil 742 } 743 744 func (b *Builder) ConvertVal(v *ast.SQLVal) sql.Expression { 745 switch v.Type { 746 case ast.StrVal: 747 return expression.NewLiteral(string(v.Val), types.CreateLongText(b.ctx.GetCollation())) 748 case ast.IntVal: 749 return b.convertInt(string(v.Val), 10) 750 case ast.FloatVal: 751 // any float value is parsed as decimal except when the value has scientific notation 752 ogVal := strings.ToLower(string(v.Val)) 753 if strings.Contains(ogVal, "e") { 754 val, err := strconv.ParseFloat(string(v.Val), 64) 755 if err != nil { 756 b.handleErr(err) 757 } 758 return expression.NewLiteral(val, types.Float64) 759 } 760 761 // using DECIMAL data type avoids precision error of rounded up float64 value 762 if ps := strings.Split(string(v.Val), "."); len(ps) == 2 { 763 p, s := expression.GetDecimalPrecisionAndScale(ogVal) 764 dt, err := types.CreateDecimalType(p, s) 765 if err != nil { 766 return expression.NewLiteral(string(v.Val), types.CreateLongText(b.ctx.GetCollation())) 767 } 768 dVal, _, err := dt.Convert(ogVal) 769 if err != nil { 770 return expression.NewLiteral(string(v.Val), types.CreateLongText(b.ctx.GetCollation())) 771 } 772 return expression.NewLiteral(dVal, dt) 773 } else { 774 // if the value is not float type - this should not happen 775 return b.convertInt(string(v.Val), 10) 776 } 777 case ast.HexNum: 778 //TODO: binary collation? 779 v := strings.ToLower(string(v.Val)) 780 if strings.HasPrefix(v, "0x") { 781 v = v[2:] 782 } else if strings.HasPrefix(v, "x") { 783 v = strings.Trim(v[1:], "'") 784 } 785 786 // pad string to even length 787 if len(v)%2 == 1 { 788 v = "0" + v 789 } 790 791 val, err := hex.DecodeString(v) 792 if err != nil { 793 b.handleErr(err) 794 } 795 return expression.NewLiteral(val, types.LongBlob) 796 case ast.HexVal: 797 //TODO: binary collation? 798 val, err := v.HexDecode() 799 if err != nil { 800 b.handleErr(err) 801 } 802 return expression.NewLiteral(val, types.LongBlob) 803 case ast.ValArg: 804 name := strings.TrimPrefix(string(v.Val), ":") 805 if b.bindCtx != nil { 806 if b.bindCtx.resolveOnly { 807 return expression.NewBindVar(name) 808 } 809 replacement := b.normalizeValArg(v) 810 return b.buildScalar(&scope{}, replacement) 811 } 812 return expression.NewBindVar(name) 813 case ast.BitVal: 814 if len(v.Val) == 0 { 815 return expression.NewLiteral(0, types.Uint64) 816 } 817 818 res, err := strconv.ParseUint(string(v.Val), 2, 64) 819 if err != nil { 820 b.handleErr(err) 821 } 822 823 return expression.NewLiteral(res, types.Uint64) 824 } 825 826 b.handleErr(sql.ErrInvalidSQLValType.New(v.Type)) 827 return nil 828 } 829 830 // processMatchAgainst returns a new MatchAgainst expression that has had 831 // all of its tables filled in. This essentially grabs the appropriate index 832 // (if it hasn't already been grabbed), and then loads the appropriate 833 // tables that are referenced by the index. The returned expression contains 834 // everything needed to calculate relevancy. 835 // 836 // A fully resolved MatchAgainst expression is also used by the index 837 // filter, since we only need to load the tables once. All steps after this 838 // one can assume that the expression has been fully resolved and is valid. 839 func (b *Builder) buildMatchAgainst(inScope *scope, v *ast.MatchExpr) *expression.MatchAgainst { 840 //TODO: implement proper scope support and remove this check 841 if (inScope.groupBy != nil && inScope.groupBy.hasAggs()) || inScope.windowFuncs != nil { 842 b.handleErr(fmt.Errorf("aggregate and window functions are not yet supported alongside MATCH expressions")) 843 } 844 rts := getTablesByName(inScope.node) 845 var rt *plan.ResolvedTable 846 var matchTable string 847 cols := make([]*expression.GetField, len(v.Columns)) 848 for i, selectExpr := range v.Columns { 849 expr := b.selectExprToExpression(inScope, selectExpr) 850 gf, ok := expr.(*expression.GetField) 851 if !ok { 852 err := sql.ErrFullTextMatchAgainstNotColumns.New() 853 b.handleErr(err) 854 } 855 if rt == nil { 856 matchTable = strings.ToLower(gf.Table()) 857 rt, ok = rts[matchTable] 858 if !ok { 859 // shouldn't be able to resolve expression without table being available 860 panic("shouldn't be able to resolve expression without table being available") 861 } 862 } else if !strings.EqualFold(matchTable, gf.Table()) { 863 err := sql.ErrFullTextMatchAgainstSameTable.New() 864 b.handleErr(err) 865 } 866 cols[i] = gf 867 } 868 matchExpr := b.buildScalar(inScope, v.Expr) 869 var searchModifier fulltext.SearchModifier 870 var err error 871 switch v.Option { 872 case ast.NaturalLanguageModeStr, "": 873 searchModifier = fulltext.SearchModifier_NaturalLanguage 874 case ast.NaturalLanguageModeWithQueryExpansionStr: 875 searchModifier = fulltext.SearchModifier_NaturalLangaugeQueryExpansion 876 err = fmt.Errorf(`"IN NATURAL LANGUAGE MODE WITH QUERY EXPANSION" is not supported yet`) 877 case ast.BooleanModeStr: 878 searchModifier = fulltext.SearchModifier_Boolean 879 err = fmt.Errorf(`"IN BOOLEAN MODE" is not supported yet`) 880 case ast.QueryExpansionStr: 881 searchModifier = fulltext.SearchModifier_QueryExpansion 882 err = fmt.Errorf(`"WITH QUERY EXPANSION" is not supported yet`) 883 default: 884 err = sql.ErrUnsupportedFeature.New(v.Option) 885 } 886 if err != nil { 887 b.handleErr(err) 888 } 889 890 innerTbl := rt.UnderlyingTable() 891 indexedTbl, ok := innerTbl.(sql.IndexAddressableTable) 892 if !ok { 893 err := fmt.Errorf("cannot use MATCH ... AGAINST ... on a table that does not declare indexes") 894 b.handleErr(err) 895 } 896 897 indexes, err := indexedTbl.GetIndexes(b.ctx) 898 if err != nil { 899 b.handleErr(err) 900 } 901 ftIndex := findMatchAgainstIndex(cols, indexes) 902 if ftIndex == nil { 903 err := sql.ErrNoFullTextIndexFound.New(indexedTbl.Name()) 904 b.handleErr(err) 905 } 906 907 // Get the key columns 908 keyCols, err := ftIndex.FullTextKeyColumns(b.ctx) 909 if err != nil { 910 b.handleErr(err) 911 } 912 913 genericCols := make([]sql.Expression, len(cols)) 914 for i, e := range cols { 915 genericCols[i] = e 916 } 917 918 // Grab the pseudo-index table names 919 tableNames, err := ftIndex.FullTextTableNames(b.ctx) 920 if err != nil { 921 b.handleErr(err) 922 } 923 924 fullindexTableNames := [5]string{tableNames.Config, tableNames.Position, tableNames.DocCount, tableNames.GlobalCount, tableNames.RowCount} 925 idxTables := make([]sql.IndexAddressableTable, 5) 926 for i, name := range fullindexTableNames { 927 configTbl, ok, err := rt.SqlDatabase.GetTableInsensitive(b.ctx, name) 928 if err != nil { 929 b.handleErr(err) 930 } 931 if !ok { 932 err := fmt.Errorf("Full-Text index `%s` on table `%s` is linked to table `%s` which could not be found", 933 ftIndex.ID(), indexedTbl.Name(), tableNames.Config) 934 b.handleErr(err) 935 } 936 idxTables[i], ok = configTbl.(sql.IndexAddressableTable) 937 if !ok { 938 err := fmt.Errorf("Full-Text index `%s` on table `%s` requires table `%s` to implement sql.IndexAddressableTable", 939 ftIndex.ID(), indexedTbl.Name(), tableNames.Config) 940 b.handleErr(err) 941 } 942 } 943 944 matchAgainst := expression.NewMatchAgainst(genericCols, matchExpr, searchModifier) 945 matchAgainst.SetIndex(ftIndex) 946 947 return matchAgainst.WithInfo(indexedTbl, idxTables[0], idxTables[1], idxTables[2], idxTables[3], idxTables[4], keyCols) 948 } 949 950 func findMatchAgainstIndex(cols []*expression.GetField, indexes []sql.Index) fulltext.Index { 951 var found fulltext.Index 952 for _, idx := range indexes { 953 idxExprs := idx.Expressions() 954 if !idx.IsFullText() || len(cols) != len(idxExprs) { 955 continue 956 } 957 // check that index expressions match |cols| 958 allMatch := true 959 for _, gf := range cols { 960 var match bool 961 for _, idxExpr := range idxExprs { 962 if gf.String() == idxExpr { 963 match = true 964 break 965 } 966 } 967 if !match { 968 allMatch = false 969 break 970 } 971 } 972 if !allMatch { 973 continue 974 } 975 var ok bool 976 found, ok = idx.(fulltext.Index) 977 if ok { 978 break 979 } 980 } 981 return found 982 }