github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/validation_rules.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "fmt" 19 "reflect" 20 "strings" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors" 24 "github.com/dolthub/go-mysql-server/sql/expression" 25 "github.com/dolthub/go-mysql-server/sql/expression/function" 26 "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation" 27 "github.com/dolthub/go-mysql-server/sql/plan" 28 "github.com/dolthub/go-mysql-server/sql/transform" 29 "github.com/dolthub/go-mysql-server/sql/types" 30 ) 31 32 // validateLimitAndOffset ensures that only integer literals are used for limit and offset values 33 func validateLimitAndOffset(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 34 var err error 35 var i, i64 interface{} 36 transform.Inspect(n, func(n sql.Node) bool { 37 switch n := n.(type) { 38 case *plan.Limit: 39 switch e := n.Limit.(type) { 40 case *expression.Literal: 41 if !types.IsInteger(e.Type()) { 42 err = sql.ErrInvalidType.New(e.Type().String()) 43 return false 44 } 45 i, err = e.Eval(ctx, nil) 46 if err != nil { 47 return false 48 } 49 50 i64, _, err = types.Int64.Convert(i) 51 if err != nil { 52 return false 53 } 54 if i64.(int64) < 0 { 55 err = sql.ErrInvalidSyntax.New("negative limit") 56 return false 57 } 58 case *expression.BindVar, *expression.ProcedureParam: 59 return true 60 default: 61 err = sql.ErrInvalidType.New(e.Type().String()) 62 return false 63 } 64 case *plan.Offset: 65 switch e := n.Offset.(type) { 66 case *expression.Literal: 67 if !types.IsInteger(e.Type()) { 68 err = sql.ErrInvalidType.New(e.Type().String()) 69 return false 70 } 71 i, err = e.Eval(ctx, nil) 72 if err != nil { 73 return false 74 } 75 76 i64, _, err = types.Int64.Convert(i) 77 if err != nil { 78 return false 79 } 80 if i64.(int64) < 0 { 81 err = sql.ErrInvalidSyntax.New("negative offset") 82 return false 83 } 84 case *expression.BindVar, *expression.ProcedureParam: 85 return true 86 default: 87 err = sql.ErrInvalidType.New(e.Type().String()) 88 return false 89 } 90 default: 91 return true 92 } 93 return true 94 }) 95 return n, transform.SameTree, err 96 } 97 98 func validateIsResolved(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 99 span, ctx := ctx.Span("validate_is_resolved") 100 defer span.End() 101 102 if !n.Resolved() { 103 return nil, transform.SameTree, unresolvedError(n) 104 } 105 106 return n, transform.SameTree, nil 107 } 108 109 // unresolvedError returns an appropriate error message for the unresolved node given 110 func unresolvedError(n sql.Node) error { 111 var err error 112 var walkFn func(sql.Expression) bool 113 walkFn = func(e sql.Expression) bool { 114 switch e := e.(type) { 115 case *plan.Subquery: 116 transform.InspectExpressions(e.Query, walkFn) 117 if err != nil { 118 return false 119 } 120 } 121 return true 122 } 123 transform.InspectExpressions(n, walkFn) 124 125 if err != nil { 126 return err 127 } 128 return analyzererrors.ErrValidationResolved.New(n) 129 } 130 131 func validateOrderBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 132 span, ctx := ctx.Span("validate_order_by") 133 defer span.End() 134 135 switch n := n.(type) { 136 case *plan.Sort: 137 for _, field := range n.SortFields { 138 switch field.Column.(type) { 139 case sql.Aggregation: 140 return nil, transform.SameTree, analyzererrors.ErrValidationOrderBy.New() 141 } 142 } 143 } 144 145 return n, transform.SameTree, nil 146 } 147 148 // validateDeleteFrom checks for invalid settings, such as deleting from multiple databases, specifying a delete target 149 // table multiple times, or using a DELETE FROM JOIN without specifying any explicit delete target tables, and returns 150 // an error if any validation issues were detected. 151 func validateDeleteFrom(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 152 span, ctx := ctx.Span("validate_order_by") 153 defer span.End() 154 155 var validationError error 156 transform.InspectUp(n, func(n sql.Node) bool { 157 df, ok := n.(*plan.DeleteFrom) 158 if !ok { 159 return false 160 } 161 162 // Check that delete from join only targets tables that exist in the join 163 if df.HasExplicitTargets() { 164 sourceTables := make(map[string]struct{}) 165 transform.Inspect(df.Child, func(node sql.Node) bool { 166 if t, ok := node.(sql.Table); ok { 167 sourceTables[t.Name()] = struct{}{} 168 } 169 return true 170 }) 171 172 for _, target := range df.GetDeleteTargets() { 173 deletable, err := plan.GetDeletable(target) 174 if err != nil { 175 validationError = err 176 return true 177 } 178 tableName := deletable.Name() 179 if _, ok := sourceTables[tableName]; !ok { 180 validationError = fmt.Errorf("table %q not found in DELETE FROM sources", tableName) 181 return true 182 } 183 } 184 } 185 186 // Duplicate explicit target tables or from explicit target tables from multiple databases 187 databases := make(map[string]struct{}) 188 tables := make(map[string]struct{}) 189 if df.HasExplicitTargets() { 190 for _, target := range df.GetDeleteTargets() { 191 // Check for multiple databases 192 databases[plan.GetDatabaseName(target)] = struct{}{} 193 if len(databases) > 1 { 194 validationError = fmt.Errorf("multiple databases specified as delete from targets") 195 return true 196 } 197 198 // Check for duplicate targets 199 nameable, ok := target.(sql.Nameable) 200 if !ok { 201 validationError = fmt.Errorf("target node does not implement sql.Nameable: %T", target) 202 return true 203 } 204 205 if _, ok := tables[nameable.Name()]; ok { 206 validationError = fmt.Errorf("duplicate tables specified as delete from targets") 207 return true 208 } 209 tables[nameable.Name()] = struct{}{} 210 } 211 } 212 213 // DELETE FROM JOIN with no target tables specified 214 deleteFromJoin := false 215 transform.Inspect(df.Child, func(node sql.Node) bool { 216 if _, ok := node.(*plan.JoinNode); ok { 217 deleteFromJoin = true 218 return false 219 } 220 return true 221 }) 222 if deleteFromJoin { 223 if df.HasExplicitTargets() == false { 224 validationError = fmt.Errorf("delete from statement with join requires specifying explicit delete target tables") 225 return true 226 } 227 } 228 return false 229 }) 230 231 if validationError != nil { 232 return nil, transform.SameTree, validationError 233 } else { 234 return n, transform.SameTree, nil 235 } 236 } 237 238 // checkSqlMode checks if the option is set for the Session in ctx 239 func checkSqlMode(ctx *sql.Context, option string) (bool, error) { 240 // session variable overrides global 241 sysVal, err := ctx.Session.GetSessionVariable(ctx, "sql_mode") 242 if err != nil { 243 return false, err 244 } 245 val, ok := sysVal.(string) 246 if !ok { 247 return false, sql.ErrSystemVariableCodeFail.New("sql_mode", val) 248 } 249 return strings.Contains(val, option), nil 250 } 251 252 func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 253 span, ctx := ctx.Span("validate_group_by") 254 defer span.End() 255 256 // only enforce strict group by when this variable is set 257 if isStrict, err := checkSqlMode(ctx, "ONLY_FULL_GROUP_BY"); err != nil { 258 return n, transform.SameTree, err 259 } else if !isStrict { 260 return n, transform.SameTree, nil 261 } 262 263 var err error 264 var parent sql.Node 265 transform.Inspect(n, func(n sql.Node) bool { 266 defer func() { 267 parent = n 268 }() 269 270 gb, ok := n.(*plan.GroupBy) 271 if !ok { 272 return true 273 } 274 275 switch parent.(type) { 276 case *plan.Having, *plan.Project, *plan.Sort: 277 // TODO: these shouldn't be skipped; you can group by primary key without problem b/c only one value 278 // https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html#:~:text=The%20query%20is%20valid%20if%20name%20is%20a%20primary%20key 279 return true 280 } 281 282 // Allow the parser use the GroupBy node to eval the aggregation functions 283 // for sql statements that don't make use of the GROUP BY expression. 284 if len(gb.GroupByExprs) == 0 { 285 return true 286 } 287 288 var groupBys []string 289 for _, expr := range gb.GroupByExprs { 290 groupBys = append(groupBys, expr.String()) 291 } 292 293 for _, expr := range gb.SelectedExprs { 294 if _, ok := expr.(sql.Aggregation); !ok { 295 if !expressionReferencesOnlyGroupBys(groupBys, expr) { 296 err = analyzererrors.ErrValidationGroupBy.New(expr.String()) 297 return false 298 } 299 } 300 } 301 return true 302 }) 303 304 return n, transform.SameTree, err 305 } 306 307 func expressionReferencesOnlyGroupBys(groupBys []string, expr sql.Expression) bool { 308 valid := true 309 sql.Inspect(expr, func(expr sql.Expression) bool { 310 switch expr := expr.(type) { 311 case nil, sql.Aggregation, *expression.Literal: 312 return false 313 case *expression.Alias, sql.FunctionExpression: 314 if stringContains(groupBys, expr.String()) { 315 return false 316 } 317 return true 318 // cc: https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html 319 // Each part of the SelectExpr must refer to the aggregated columns in some way 320 // TODO: this isn't complete, it's overly restrictive. Dependant columns are fine to reference. 321 default: 322 if stringContains(groupBys, expr.String()) { 323 return false 324 } 325 326 if len(expr.Children()) == 0 { 327 valid = false 328 return false 329 } 330 331 return true 332 } 333 }) 334 335 return valid 336 } 337 338 func validateSchemaSource(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 339 span, ctx := ctx.Span("validate_schema_source") 340 defer span.End() 341 342 switch n := n.(type) { 343 case *plan.TableAlias: 344 // table aliases should not be validated 345 if child, ok := n.Child.(*plan.ResolvedTable); ok { 346 return n, transform.SameTree, validateSchema(child) 347 } 348 case *plan.ResolvedTable: 349 return n, transform.SameTree, validateSchema(n) 350 } 351 return n, transform.SameTree, nil 352 } 353 354 func validateIndexCreation(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 355 span, ctx := ctx.Span("validate_index_creation") 356 defer span.End() 357 358 ci, ok := n.(*plan.CreateIndex) 359 if !ok { 360 return n, transform.SameTree, nil 361 } 362 363 schema := ci.Table.Schema() 364 table := schema[0].Source 365 366 var unknownColumns []string 367 for _, expr := range ci.Exprs { 368 sql.Inspect(expr, func(e sql.Expression) bool { 369 gf, ok := e.(*expression.GetField) 370 if ok { 371 if gf.Table() != table || !schema.Contains(gf.Name(), gf.Table()) { 372 unknownColumns = append(unknownColumns, gf.Name()) 373 } 374 } 375 return true 376 }) 377 } 378 379 if len(unknownColumns) > 0 { 380 return nil, transform.SameTree, analyzererrors.ErrUnknownIndexColumns.New(table, strings.Join(unknownColumns, ", ")) 381 } 382 383 return n, transform.SameTree, nil 384 } 385 386 func validateSchema(t *plan.ResolvedTable) error { 387 for _, col := range t.Schema() { 388 if col.Source == "" { 389 return analyzererrors.ErrValidationSchemaSource.New() 390 } 391 } 392 return nil 393 } 394 395 func validateUnionSchemasMatch(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 396 span, ctx := ctx.Span("validate_union_schemas_match") 397 defer span.End() 398 399 var firstmismatch []string 400 transform.Inspect(n, func(n sql.Node) bool { 401 if u, ok := n.(*plan.SetOp); ok { 402 ls := u.Left().Schema() 403 rs := u.Right().Schema() 404 if len(ls) != len(rs) { 405 firstmismatch = []string{ 406 fmt.Sprintf("%d columns", len(ls)), 407 fmt.Sprintf("%d columns", len(rs)), 408 } 409 return false 410 } 411 for i := range ls { 412 if !reflect.DeepEqual(ls[i].Type, rs[i].Type) { 413 firstmismatch = []string{ 414 ls[i].Type.String(), 415 rs[i].Type.String(), 416 } 417 return false 418 } 419 } 420 } 421 return true 422 }) 423 if firstmismatch != nil { 424 return nil, transform.SameTree, analyzererrors.ErrUnionSchemasMatch.New(firstmismatch[0], firstmismatch[1]) 425 } 426 return n, transform.SameTree, nil 427 } 428 429 func validateIntervalUsage(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 430 var invalid bool 431 transform.InspectExpressionsWithNode(n, func(node sql.Node, e sql.Expression) bool { 432 // If it's already invalid just skip everything else. 433 if invalid { 434 return false 435 } 436 437 // Interval can be used without DATE_ADD/DATE_SUB functions in CREATE/ALTER EVENTS statements. 438 switch node.(type) { 439 case *plan.CreateEvent, *plan.AlterEvent: 440 return false 441 } 442 443 switch e := e.(type) { 444 case *function.DateAdd, *function.DateSub: 445 return false 446 case *expression.Arithmetic: 447 if e.Op == "+" || e.Op == "-" { 448 return false 449 } 450 case *expression.Interval: 451 invalid = true 452 } 453 454 return true 455 }) 456 457 if invalid { 458 return nil, transform.SameTree, analyzererrors.ErrIntervalInvalidUse.New() 459 } 460 461 return n, transform.SameTree, nil 462 } 463 464 func validateStarExpressions(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 465 // Validate that all occurences of the '*' placeholder expression are in a context that makes sense. 466 // 467 // That is, all uses of '*' should be either: 468 // - The top level of an expression. 469 // - The input to a COUNT or JSONARRAY function. 470 // 471 // We do not use plan.InspectExpressions here because we're treating 472 // the top-level expressions of sql.Node differently from subexpressions. 473 var err error 474 transform.Inspect(n, func(n sql.Node) bool { 475 if er, ok := n.(sql.Expressioner); ok { 476 for _, e := range er.Expressions() { 477 // An expression consisting of just a * is allowed. 478 if _, s := e.(*expression.Star); s { 479 return false 480 } 481 // Otherwise, * can only be used inside acceptable aggregation functions. 482 // Detect any uses of * outside such functions. 483 sql.Inspect(e, func(e sql.Expression) bool { 484 if err != nil { 485 return false 486 } 487 switch e.(type) { 488 case *expression.Star: 489 err = sql.ErrStarUnsupported.New() 490 return false 491 case *aggregation.Count, *aggregation.CountDistinct, *aggregation.JsonArray: 492 if _, s := e.Children()[0].(*expression.Star); s { 493 return false 494 } 495 } 496 return true 497 }) 498 } 499 } 500 return err == nil 501 }) 502 if err != nil { 503 return nil, transform.SameTree, err 504 } 505 return n, transform.SameTree, nil 506 } 507 508 func validateOperands(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 509 // Validate that the number of columns in an operand or a top level 510 // expression are as expected. The current rules are: 511 // * Every top level expression of a node must have 1 column. 512 // * The following expression nodes are allowed to have `n` columns as 513 // long as `n` matches: 514 // * *plan.InSubquery, *expression.{Equals,NullSafeEquals,GreaterThan,LessThan,GreaterThanOrEqual,LessThanOrEqual} 515 // * *expression.InTuple must have a tuple on the right side, the # of 516 // columns for each element of the tuple must match the number of 517 // columns of the expression on the left. 518 // * Every other expression with operands must have NumColumns == 1. 519 520 // We do not use plan.InspectExpressions here because we're treating 521 // top-level expressions of sql.Node differently from subexpressions. 522 var err error 523 transform.Inspect(n, func(n sql.Node) bool { 524 if n == nil { 525 return false 526 } 527 528 if plan.IsDDLNode(n) { 529 return false 530 } 531 532 if er, ok := n.(sql.Expressioner); ok { 533 for _, e := range er.Expressions() { 534 nc := types.NumColumns(e.Type()) 535 if nc != 1 { 536 if _, ok := er.(*plan.HashLookup); ok { 537 // hash lookup expressions are tuples with >= 1 columns 538 return true 539 } 540 err = sql.ErrInvalidOperandColumns.New(1, nc) 541 return false 542 } 543 sql.Inspect(e, func(e sql.Expression) bool { 544 if e == nil { 545 return err == nil 546 } 547 if err != nil { 548 return false 549 } 550 switch e.(type) { 551 case *plan.InSubquery, *expression.Equals, *expression.NullSafeEquals, *expression.GreaterThan, 552 *expression.LessThan, *expression.GreaterThanOrEqual, *expression.LessThanOrEqual: 553 err = types.ErrIfMismatchedColumns(e.Children()[0].Type(), e.Children()[1].Type()) 554 case *expression.InTuple, *expression.HashInTuple: 555 t, ok := e.Children()[1].(expression.Tuple) 556 if ok && len(t.Children()) == 1 { 557 // A single element Tuple treats itself like the element it contains. 558 err = types.ErrIfMismatchedColumns(e.Children()[0].Type(), e.Children()[1].Type()) 559 } else { 560 err = types.ErrIfMismatchedColumnsInTuple(e.Children()[0].Type(), e.Children()[1].Type()) 561 } 562 case *aggregation.Count, *aggregation.CountDistinct, *aggregation.JsonArray: 563 if _, s := e.Children()[0].(*expression.Star); s { 564 return false 565 } 566 for _, e := range e.Children() { 567 nc := types.NumColumns(e.Type()) 568 if nc != 1 { 569 err = sql.ErrInvalidOperandColumns.New(1, nc) 570 } 571 } 572 case expression.Tuple: 573 // Tuple expressions can contain tuples... 574 case *plan.ExistsSubquery: 575 // Any number of columns are allowed. 576 default: 577 for _, e := range e.Children() { 578 nc := types.NumColumns(e.Type()) 579 if nc != 1 { 580 err = sql.ErrInvalidOperandColumns.New(1, nc) 581 } 582 } 583 } 584 return err == nil 585 }) 586 } 587 } 588 return err == nil 589 }) 590 if err != nil { 591 return nil, transform.SameTree, err 592 } 593 return n, transform.SameTree, nil 594 } 595 596 func validateSubqueryColumns(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 597 // Then validate that every subquery has field indexes within the correct range 598 // TODO: Why is this only for subqueries? 599 600 // TODO: Currently disabled. 601 if true { 602 return n, transform.SameTree, nil 603 } 604 605 var outOfRangeIndexExpression sql.Expression 606 var outOfRangeColumns int 607 transform.InspectExpressionsWithNode(n, func(n sql.Node, e sql.Expression) bool { 608 s, ok := e.(*plan.Subquery) 609 if !ok { 610 return true 611 } 612 613 outerScopeRowLen := len(scope.Schema()) + len(Schemas(n.Children())) 614 transform.Inspect(s.Query, func(n sql.Node) bool { 615 if n == nil { 616 return true 617 } 618 // TODO: the schema of the rows seen by children of 619 // these nodes are not reflected in the schema 620 // calculations here. This needs to be rationalized 621 // across the analyzer. 622 switch n := n.(type) { 623 case *plan.JoinNode: 624 return !n.Op.IsLookup() 625 default: 626 } 627 if es, ok := n.(sql.Expressioner); ok { 628 childSchemaLen := len(Schemas(n.Children())) 629 for _, e := range es.Expressions() { 630 sql.Inspect(e, func(e sql.Expression) bool { 631 if gf, ok := e.(*expression.GetField); ok { 632 if gf.Index() >= outerScopeRowLen+childSchemaLen { 633 outOfRangeIndexExpression = gf 634 outOfRangeColumns = outerScopeRowLen + childSchemaLen 635 } 636 } 637 return outOfRangeIndexExpression == nil 638 }) 639 } 640 } 641 return outOfRangeIndexExpression == nil 642 }) 643 return outOfRangeIndexExpression == nil 644 }) 645 if outOfRangeIndexExpression != nil { 646 return nil, transform.SameTree, analyzererrors.ErrSubqueryFieldIndex.New(outOfRangeIndexExpression, outOfRangeColumns) 647 } 648 649 return n, transform.SameTree, nil 650 } 651 652 func stringContains(strs []string, target string) bool { 653 lowerTarget := strings.ToLower(target) 654 for _, s := range strs { 655 if lowerTarget == strings.ToLower(s) { 656 return true 657 } 658 } 659 return false 660 } 661 662 func tableColsContains(strs []tableCol, target tableCol) bool { 663 for _, s := range strs { 664 if s == target { 665 return true 666 } 667 } 668 return false 669 } 670 671 // validateReadOnlyDatabase invalidates queries that attempt to write to ReadOnlyDatabases. 672 func validateReadOnlyDatabase(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 673 valid := true 674 var readOnlyDB sql.ReadOnlyDatabase 675 enforceReadOnly := scope.EnforcesReadOnly() 676 677 // if a ReadOnlyDatabase is found, invalidate the query 678 readOnlyDBSearch := func(node sql.Node) bool { 679 if rt, ok := node.(*plan.ResolvedTable); ok { 680 if ro, ok := rt.SqlDatabase.(sql.ReadOnlyDatabase); ok { 681 if ro.IsReadOnly() { 682 readOnlyDB = ro 683 valid = false 684 } else if enforceReadOnly { 685 valid = false 686 } 687 } 688 } 689 return valid 690 } 691 692 transform.Inspect(n, func(node sql.Node) bool { 693 switch n := n.(type) { 694 case *plan.DeleteFrom, *plan.Update, *plan.LockTables, *plan.UnlockTables: 695 transform.Inspect(node, readOnlyDBSearch) 696 return false 697 698 case *plan.InsertInto: 699 // ReadOnlyDatabase can be an insertion Source, 700 // only inspect the Destination tree 701 transform.Inspect(n.Destination, readOnlyDBSearch) 702 return false 703 704 case *plan.CreateTable: 705 if ro, ok := n.Database().(sql.ReadOnlyDatabase); ok { 706 if ro.IsReadOnly() { 707 readOnlyDB = ro 708 valid = false 709 } else if enforceReadOnly { 710 valid = false 711 } 712 } 713 // "CREATE TABLE ... LIKE ..." and 714 // "CREATE TABLE ... AS ..." 715 // can both use ReadOnlyDatabases as a source, 716 // so don't descend here. 717 return false 718 719 default: 720 // CreateTable is the only DDL node allowed 721 // to contain a ReadOnlyDatabase 722 if plan.IsDDLNode(n) { 723 transform.Inspect(n, readOnlyDBSearch) 724 return false 725 } 726 } 727 728 return valid 729 }) 730 if !valid { 731 if enforceReadOnly { 732 return nil, transform.SameTree, sql.ErrProcedureCallAsOfReadOnly.New() 733 } else { 734 return nil, transform.SameTree, analyzererrors.ErrReadOnlyDatabase.New(readOnlyDB.Name()) 735 } 736 } 737 738 return n, transform.SameTree, nil 739 } 740 741 // validateReadOnlyTransaction invalidates read only transactions that try to perform improper write operations. 742 func validateReadOnlyTransaction(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 743 t := ctx.GetTransaction() 744 745 if t == nil { 746 return n, transform.SameTree, nil 747 } 748 749 // If this is a normal read write transaction don't enforce read-only. Otherwise we must prevent an invalid query. 750 if !t.IsReadOnly() && !scope.EnforcesReadOnly() { 751 return n, transform.SameTree, nil 752 } 753 754 valid := true 755 756 isTempTable := func(table sql.Table) bool { 757 tt, isTempTable := table.(sql.TemporaryTable) 758 if !isTempTable { 759 valid = false 760 } 761 762 return tt.IsTemporary() 763 } 764 765 temporaryTableSearch := func(node sql.Node) bool { 766 if rt, ok := node.(*plan.ResolvedTable); ok { 767 valid = isTempTable(rt.Table) 768 } 769 return valid 770 } 771 772 transform.Inspect(n, func(node sql.Node) bool { 773 switch n := n.(type) { 774 case *plan.DeleteFrom, *plan.Update, *plan.UnlockTables: 775 transform.Inspect(node, temporaryTableSearch) 776 return false 777 case *plan.InsertInto: 778 transform.Inspect(n.Destination, temporaryTableSearch) 779 return false 780 case *plan.LockTables: 781 // TODO: Technically we should allow for the locking of temporary tables but the LockTables implementation 782 // needs substantial refactoring. 783 valid = false 784 return false 785 case *plan.CreateTable: 786 // MySQL explicitly blocks the creation of temporary tables in a read only transaction. 787 if n.Temporary() == plan.IsTempTable { 788 valid = false 789 } 790 791 return false 792 default: 793 // DDL statements have an implicit commits which makes them valid to be executed in READ ONLY transactions. 794 if plan.IsDDLNode(n) { 795 valid = true 796 return false 797 } 798 799 return valid 800 } 801 }) 802 803 if !valid { 804 return nil, transform.SameTree, sql.ErrReadOnlyTransaction.New() 805 } 806 807 return n, transform.SameTree, nil 808 } 809 810 // validateAggregations returns an error if an Aggregation expression has been used in 811 // an invalid way, such as appearing outside of a GroupBy or Window node, or if an aggregate 812 // function is used with the implicit all-rows grouping and contains projected expressions with 813 // window aggregation functions that reference non-aggregated columns. Only GroupBy and Window 814 // nodes know how to evaluate Aggregation expressions. 815 // 816 // See https://github.com/dolthub/go-mysql-server/issues/542 for some queries 817 // that should be supported but that currently trigger this validation because 818 // aggregation expressions end up in the wrong place. 819 func validateAggregations(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 820 var validationErr error 821 transform.Inspect(n, func(n sql.Node) bool { 822 switch n := n.(type) { 823 case *plan.GroupBy: 824 validationErr = checkForAggregationFunctions(n.GroupByExprs) 825 case *plan.Window: 826 validationErr = checkForNonAggregatedColumnReferences(n) 827 case sql.Expressioner: 828 validationErr = checkForAggregationFunctions(n.Expressions()) 829 default: 830 } 831 return validationErr == nil 832 }) 833 834 return n, transform.SameTree, validationErr 835 } 836 837 // checkForAggregationFunctions returns an ErrAggregationUnsupported error if any aggregation 838 // functions are found in the specified expressions. 839 func checkForAggregationFunctions(exprs []sql.Expression) error { 840 var validationErr error 841 for _, e := range exprs { 842 sql.Inspect(e, func(ie sql.Expression) bool { 843 if _, ok := ie.(sql.Aggregation); ok { 844 validationErr = sql.ErrAggregationUnsupported.New(e.String()) 845 } 846 return validationErr == nil 847 }) 848 } 849 return validationErr 850 } 851 852 // checkForNonAggregatedColumnReferences returns an ErrNonAggregatedColumnWithoutGroupBy error 853 // if an aggregate function with the implicit/all-rows grouping is mixed with aggregate window 854 // functions that reference a non-aggregated column. 855 // You cannot mix aggregations on the implicit/all-rows grouping with window aggregations. 856 func checkForNonAggregatedColumnReferences(w *plan.Window) error { 857 for _, expr := range w.ProjectedExprs() { 858 if agg, ok := expr.(sql.Aggregation); ok { 859 if agg.Window() == nil { 860 index, gf := findFirstWindowAggregationColumnReference(w) 861 862 if index >= 0 { 863 return sql.ErrNonAggregatedColumnWithoutGroupBy.New(index, gf.String()) 864 } else { 865 // We should always have an index and GetField value to use, but just in case 866 // something changes that, return a similar error message without those details. 867 return fmt.Errorf("in aggregated query without GROUP BY, expression in " + 868 "SELECT list contains nonaggregated column; " + 869 "this is incompatible with sql_mode=only_full_group_by") 870 } 871 } 872 } 873 } 874 return nil 875 } 876 877 // findFirstWindowAggregationColumnReference returns the index and GetField expression for the 878 // first column reference in the first window aggregation function in the specified node's 879 // projection expressions. If no window aggregation function with a column reference is found, 880 // (-1, nil) is returned. This information is needed to populate an 881 // ErrNonAggregatedColumnWithoutGroupBy error. 882 func findFirstWindowAggregationColumnReference(w *plan.Window) (index int, gf *expression.GetField) { 883 for index, expr := range w.ProjectedExprs() { 884 var firstColumnRef *expression.GetField 885 886 transform.InspectExpr(expr, func(e sql.Expression) bool { 887 if windowAgg, ok := e.(sql.WindowAggregation); ok { 888 transform.InspectExpr(windowAgg, func(e sql.Expression) bool { 889 if gf, ok := e.(*expression.GetField); ok { 890 firstColumnRef = gf 891 return true 892 } 893 return false 894 }) 895 return firstColumnRef != nil 896 } 897 return false 898 }) 899 900 if firstColumnRef != nil { 901 return index, firstColumnRef 902 } 903 } 904 905 return -1, nil 906 } 907 908 func validateExprSem(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 909 var err error 910 transform.InspectExpressions(n, func(e sql.Expression) bool { 911 err = validateSem(e) 912 return err == nil 913 }) 914 return n, transform.SameTree, err 915 } 916 917 // validateSem is a way to add validation logic for 918 // specific expression types. 919 // todo(max): Refactor and consolidate validation so it can 920 // run before the rest of analysis. Add more expression types. 921 // Add node equivalent. 922 func validateSem(e sql.Expression) error { 923 switch e := e.(type) { 924 case *expression.And: 925 if err := logicalSem(e.BinaryExpressionStub); err != nil { 926 return err 927 } 928 case *expression.Or: 929 if err := logicalSem(e.BinaryExpressionStub); err != nil { 930 return err 931 } 932 default: 933 } 934 return nil 935 } 936 937 func logicalSem(e expression.BinaryExpressionStub) error { 938 if lc := fds(e.LeftChild); lc != 1 { 939 return sql.ErrInvalidOperandColumns.New(1, lc) 940 } 941 if rc := fds(e.RightChild); rc != 1 { 942 return sql.ErrInvalidOperandColumns.New(1, rc) 943 } 944 return nil 945 } 946 947 // fds counts the functional dependencies of an expression. 948 // todo(max): input/output fd's should be part of the expression 949 // interface. 950 func fds(e sql.Expression) int { 951 switch e.(type) { 952 case *expression.UnresolvedColumn: 953 return 1 954 case *expression.UnresolvedFunction: 955 return 1 956 default: 957 return types.NumColumns(e.Type()) 958 } 959 }