github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/validation_rules.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"fmt"
    19  	"reflect"
    20  	"strings"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/expression/function"
    26  	"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation"
    27  	"github.com/dolthub/go-mysql-server/sql/plan"
    28  	"github.com/dolthub/go-mysql-server/sql/transform"
    29  	"github.com/dolthub/go-mysql-server/sql/types"
    30  )
    31  
    32  // validateLimitAndOffset ensures that only integer literals are used for limit and offset values
    33  func validateLimitAndOffset(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    34  	var err error
    35  	var i, i64 interface{}
    36  	transform.Inspect(n, func(n sql.Node) bool {
    37  		switch n := n.(type) {
    38  		case *plan.Limit:
    39  			switch e := n.Limit.(type) {
    40  			case *expression.Literal:
    41  				if !types.IsInteger(e.Type()) {
    42  					err = sql.ErrInvalidType.New(e.Type().String())
    43  					return false
    44  				}
    45  				i, err = e.Eval(ctx, nil)
    46  				if err != nil {
    47  					return false
    48  				}
    49  
    50  				i64, _, err = types.Int64.Convert(i)
    51  				if err != nil {
    52  					return false
    53  				}
    54  				if i64.(int64) < 0 {
    55  					err = sql.ErrInvalidSyntax.New("negative limit")
    56  					return false
    57  				}
    58  			case *expression.BindVar, *expression.ProcedureParam:
    59  				return true
    60  			default:
    61  				err = sql.ErrInvalidType.New(e.Type().String())
    62  				return false
    63  			}
    64  		case *plan.Offset:
    65  			switch e := n.Offset.(type) {
    66  			case *expression.Literal:
    67  				if !types.IsInteger(e.Type()) {
    68  					err = sql.ErrInvalidType.New(e.Type().String())
    69  					return false
    70  				}
    71  				i, err = e.Eval(ctx, nil)
    72  				if err != nil {
    73  					return false
    74  				}
    75  
    76  				i64, _, err = types.Int64.Convert(i)
    77  				if err != nil {
    78  					return false
    79  				}
    80  				if i64.(int64) < 0 {
    81  					err = sql.ErrInvalidSyntax.New("negative offset")
    82  					return false
    83  				}
    84  			case *expression.BindVar, *expression.ProcedureParam:
    85  				return true
    86  			default:
    87  				err = sql.ErrInvalidType.New(e.Type().String())
    88  				return false
    89  			}
    90  		default:
    91  			return true
    92  		}
    93  		return true
    94  	})
    95  	return n, transform.SameTree, err
    96  }
    97  
    98  func validateIsResolved(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    99  	span, ctx := ctx.Span("validate_is_resolved")
   100  	defer span.End()
   101  
   102  	if !n.Resolved() {
   103  		return nil, transform.SameTree, unresolvedError(n)
   104  	}
   105  
   106  	return n, transform.SameTree, nil
   107  }
   108  
// unresolvedError returns an appropriate error message for the unresolved node given.
// It walks the node's expressions (recursing into subquery plans) looking for a more
// specific error to surface, and falls back to the generic ErrValidationResolved.
func unresolvedError(n sql.Node) error {
	var err error
	var walkFn func(sql.Expression) bool
	walkFn = func(e sql.Expression) bool {
		switch e := e.(type) {
		case *plan.Subquery:
			// Recurse into the subquery's own plan tree with the same walker.
			transform.InspectExpressions(e.Query, walkFn)
			if err != nil {
				return false
			}
		}
		return true
	}
	transform.InspectExpressions(n, walkFn)

	// NOTE(review): walkFn never assigns err in the code shown here, so this
	// branch appears unreachable and the walk may be vestigial — confirm
	// whether a more specific error was intended to be captured before
	// simplifying.
	if err != nil {
		return err
	}
	return analyzererrors.ErrValidationResolved.New(n)
}
   130  
   131  func validateOrderBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   132  	span, ctx := ctx.Span("validate_order_by")
   133  	defer span.End()
   134  
   135  	switch n := n.(type) {
   136  	case *plan.Sort:
   137  		for _, field := range n.SortFields {
   138  			switch field.Column.(type) {
   139  			case sql.Aggregation:
   140  				return nil, transform.SameTree, analyzererrors.ErrValidationOrderBy.New()
   141  			}
   142  		}
   143  	}
   144  
   145  	return n, transform.SameTree, nil
   146  }
   147  
   148  // validateDeleteFrom checks for invalid settings, such as deleting from multiple databases, specifying a delete target
   149  // table multiple times, or using a DELETE FROM JOIN without specifying any explicit delete target tables, and returns
   150  // an error if any validation issues were detected.
   151  func validateDeleteFrom(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   152  	span, ctx := ctx.Span("validate_order_by")
   153  	defer span.End()
   154  
   155  	var validationError error
   156  	transform.InspectUp(n, func(n sql.Node) bool {
   157  		df, ok := n.(*plan.DeleteFrom)
   158  		if !ok {
   159  			return false
   160  		}
   161  
   162  		// Check that delete from join only targets tables that exist in the join
   163  		if df.HasExplicitTargets() {
   164  			sourceTables := make(map[string]struct{})
   165  			transform.Inspect(df.Child, func(node sql.Node) bool {
   166  				if t, ok := node.(sql.Table); ok {
   167  					sourceTables[t.Name()] = struct{}{}
   168  				}
   169  				return true
   170  			})
   171  
   172  			for _, target := range df.GetDeleteTargets() {
   173  				deletable, err := plan.GetDeletable(target)
   174  				if err != nil {
   175  					validationError = err
   176  					return true
   177  				}
   178  				tableName := deletable.Name()
   179  				if _, ok := sourceTables[tableName]; !ok {
   180  					validationError = fmt.Errorf("table %q not found in DELETE FROM sources", tableName)
   181  					return true
   182  				}
   183  			}
   184  		}
   185  
   186  		// Duplicate explicit target tables or from explicit target tables from multiple databases
   187  		databases := make(map[string]struct{})
   188  		tables := make(map[string]struct{})
   189  		if df.HasExplicitTargets() {
   190  			for _, target := range df.GetDeleteTargets() {
   191  				// Check for multiple databases
   192  				databases[plan.GetDatabaseName(target)] = struct{}{}
   193  				if len(databases) > 1 {
   194  					validationError = fmt.Errorf("multiple databases specified as delete from targets")
   195  					return true
   196  				}
   197  
   198  				// Check for duplicate targets
   199  				nameable, ok := target.(sql.Nameable)
   200  				if !ok {
   201  					validationError = fmt.Errorf("target node does not implement sql.Nameable: %T", target)
   202  					return true
   203  				}
   204  
   205  				if _, ok := tables[nameable.Name()]; ok {
   206  					validationError = fmt.Errorf("duplicate tables specified as delete from targets")
   207  					return true
   208  				}
   209  				tables[nameable.Name()] = struct{}{}
   210  			}
   211  		}
   212  
   213  		// DELETE FROM JOIN with no target tables specified
   214  		deleteFromJoin := false
   215  		transform.Inspect(df.Child, func(node sql.Node) bool {
   216  			if _, ok := node.(*plan.JoinNode); ok {
   217  				deleteFromJoin = true
   218  				return false
   219  			}
   220  			return true
   221  		})
   222  		if deleteFromJoin {
   223  			if df.HasExplicitTargets() == false {
   224  				validationError = fmt.Errorf("delete from statement with join requires specifying explicit delete target tables")
   225  				return true
   226  			}
   227  		}
   228  		return false
   229  	})
   230  
   231  	if validationError != nil {
   232  		return nil, transform.SameTree, validationError
   233  	} else {
   234  		return n, transform.SameTree, nil
   235  	}
   236  }
   237  
   238  // checkSqlMode checks if the option is set for the Session in ctx
   239  func checkSqlMode(ctx *sql.Context, option string) (bool, error) {
   240  	// session variable overrides global
   241  	sysVal, err := ctx.Session.GetSessionVariable(ctx, "sql_mode")
   242  	if err != nil {
   243  		return false, err
   244  	}
   245  	val, ok := sysVal.(string)
   246  	if !ok {
   247  		return false, sql.ErrSystemVariableCodeFail.New("sql_mode", val)
   248  	}
   249  	return strings.Contains(val, option), nil
   250  }
   251  
// validateGroupBy enforces ONLY_FULL_GROUP_BY semantics: when that sql_mode is
// set, every selected non-aggregate expression of a GROUP BY must reference
// only expressions that appear in the GROUP BY list.
func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
	span, ctx := ctx.Span("validate_group_by")
	defer span.End()

	// only enforce strict group by when this variable is set
	if isStrict, err := checkSqlMode(ctx, "ONLY_FULL_GROUP_BY"); err != nil {
		return n, transform.SameTree, err
	} else if !isStrict {
		return n, transform.SameTree, nil
	}

	var err error
	var parent sql.Node
	transform.Inspect(n, func(n sql.Node) bool {
		// Record the previously visited node so a GroupBy can see what wraps it.
		// NOTE(review): Inspect is a pre-order walk, so `parent` is the node
		// visited immediately before this one — the true parent only when this
		// node is a first child. Confirm this approximation is intended.
		defer func() {
			parent = n
		}()

		gb, ok := n.(*plan.GroupBy)
		if !ok {
			return true
		}

		switch parent.(type) {
		case *plan.Having, *plan.Project, *plan.Sort:
			// TODO: these shouldn't be skipped; you can group by primary key without problem b/c only one value
			// https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html#:~:text=The%20query%20is%20valid%20if%20name%20is%20a%20primary%20key
			return true
		}

		// Allow the parser to use the GroupBy node to eval the aggregation functions
		// for sql statements that don't make use of the GROUP BY expression.
		if len(gb.GroupByExprs) == 0 {
			return true
		}

		// Collect string forms of the GROUP BY expressions for textual matching.
		var groupBys []string
		for _, expr := range gb.GroupByExprs {
			groupBys = append(groupBys, expr.String())
		}

		// Every selected non-aggregate expression must be derivable from the
		// GROUP BY expressions; otherwise the query is invalid under
		// ONLY_FULL_GROUP_BY.
		for _, expr := range gb.SelectedExprs {
			if _, ok := expr.(sql.Aggregation); !ok {
				if !expressionReferencesOnlyGroupBys(groupBys, expr) {
					err = analyzererrors.ErrValidationGroupBy.New(expr.String())
					return false
				}
			}
		}
		return true
	})

	return n, transform.SameTree, err
}
   306  
   307  func expressionReferencesOnlyGroupBys(groupBys []string, expr sql.Expression) bool {
   308  	valid := true
   309  	sql.Inspect(expr, func(expr sql.Expression) bool {
   310  		switch expr := expr.(type) {
   311  		case nil, sql.Aggregation, *expression.Literal:
   312  			return false
   313  		case *expression.Alias, sql.FunctionExpression:
   314  			if stringContains(groupBys, expr.String()) {
   315  				return false
   316  			}
   317  			return true
   318  		// cc: https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html
   319  		// Each part of the SelectExpr must refer to the aggregated columns in some way
   320  		// TODO: this isn't complete, it's overly restrictive. Dependant columns are fine to reference.
   321  		default:
   322  			if stringContains(groupBys, expr.String()) {
   323  				return false
   324  			}
   325  
   326  			if len(expr.Children()) == 0 {
   327  				valid = false
   328  				return false
   329  			}
   330  
   331  			return true
   332  		}
   333  	})
   334  
   335  	return valid
   336  }
   337  
   338  func validateSchemaSource(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   339  	span, ctx := ctx.Span("validate_schema_source")
   340  	defer span.End()
   341  
   342  	switch n := n.(type) {
   343  	case *plan.TableAlias:
   344  		// table aliases should not be validated
   345  		if child, ok := n.Child.(*plan.ResolvedTable); ok {
   346  			return n, transform.SameTree, validateSchema(child)
   347  		}
   348  	case *plan.ResolvedTable:
   349  		return n, transform.SameTree, validateSchema(n)
   350  	}
   351  	return n, transform.SameTree, nil
   352  }
   353  
   354  func validateIndexCreation(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   355  	span, ctx := ctx.Span("validate_index_creation")
   356  	defer span.End()
   357  
   358  	ci, ok := n.(*plan.CreateIndex)
   359  	if !ok {
   360  		return n, transform.SameTree, nil
   361  	}
   362  
   363  	schema := ci.Table.Schema()
   364  	table := schema[0].Source
   365  
   366  	var unknownColumns []string
   367  	for _, expr := range ci.Exprs {
   368  		sql.Inspect(expr, func(e sql.Expression) bool {
   369  			gf, ok := e.(*expression.GetField)
   370  			if ok {
   371  				if gf.Table() != table || !schema.Contains(gf.Name(), gf.Table()) {
   372  					unknownColumns = append(unknownColumns, gf.Name())
   373  				}
   374  			}
   375  			return true
   376  		})
   377  	}
   378  
   379  	if len(unknownColumns) > 0 {
   380  		return nil, transform.SameTree, analyzererrors.ErrUnknownIndexColumns.New(table, strings.Join(unknownColumns, ", "))
   381  	}
   382  
   383  	return n, transform.SameTree, nil
   384  }
   385  
   386  func validateSchema(t *plan.ResolvedTable) error {
   387  	for _, col := range t.Schema() {
   388  		if col.Source == "" {
   389  			return analyzererrors.ErrValidationSchemaSource.New()
   390  		}
   391  	}
   392  	return nil
   393  }
   394  
   395  func validateUnionSchemasMatch(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   396  	span, ctx := ctx.Span("validate_union_schemas_match")
   397  	defer span.End()
   398  
   399  	var firstmismatch []string
   400  	transform.Inspect(n, func(n sql.Node) bool {
   401  		if u, ok := n.(*plan.SetOp); ok {
   402  			ls := u.Left().Schema()
   403  			rs := u.Right().Schema()
   404  			if len(ls) != len(rs) {
   405  				firstmismatch = []string{
   406  					fmt.Sprintf("%d columns", len(ls)),
   407  					fmt.Sprintf("%d columns", len(rs)),
   408  				}
   409  				return false
   410  			}
   411  			for i := range ls {
   412  				if !reflect.DeepEqual(ls[i].Type, rs[i].Type) {
   413  					firstmismatch = []string{
   414  						ls[i].Type.String(),
   415  						rs[i].Type.String(),
   416  					}
   417  					return false
   418  				}
   419  			}
   420  		}
   421  		return true
   422  	})
   423  	if firstmismatch != nil {
   424  		return nil, transform.SameTree, analyzererrors.ErrUnionSchemasMatch.New(firstmismatch[0], firstmismatch[1])
   425  	}
   426  	return n, transform.SameTree, nil
   427  }
   428  
   429  func validateIntervalUsage(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   430  	var invalid bool
   431  	transform.InspectExpressionsWithNode(n, func(node sql.Node, e sql.Expression) bool {
   432  		// If it's already invalid just skip everything else.
   433  		if invalid {
   434  			return false
   435  		}
   436  
   437  		// Interval can be used without DATE_ADD/DATE_SUB functions in CREATE/ALTER EVENTS statements.
   438  		switch node.(type) {
   439  		case *plan.CreateEvent, *plan.AlterEvent:
   440  			return false
   441  		}
   442  
   443  		switch e := e.(type) {
   444  		case *function.DateAdd, *function.DateSub:
   445  			return false
   446  		case *expression.Arithmetic:
   447  			if e.Op == "+" || e.Op == "-" {
   448  				return false
   449  			}
   450  		case *expression.Interval:
   451  			invalid = true
   452  		}
   453  
   454  		return true
   455  	})
   456  
   457  	if invalid {
   458  		return nil, transform.SameTree, analyzererrors.ErrIntervalInvalidUse.New()
   459  	}
   460  
   461  	return n, transform.SameTree, nil
   462  }
   463  
func validateStarExpressions(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
	// Validate that all occurrences of the '*' placeholder expression are in a context that makes sense.
	//
	// That is, all uses of '*' should be either:
	// - The top level of an expression.
	// - The input to a COUNT or JSONARRAY function.
	//
	// We do not use plan.InspectExpressions here because we're treating
	// the top-level expressions of sql.Node differently from subexpressions.
	var err error
	transform.Inspect(n, func(n sql.Node) bool {
		if er, ok := n.(sql.Expressioner); ok {
			for _, e := range er.Expressions() {
				// An expression consisting of just a * is allowed.
				// NOTE(review): this `return false` exits the Inspect callback
				// entirely, so the remaining expressions of this node are not
				// checked — confirm whether that short-circuit is intended.
				if _, s := e.(*expression.Star); s {
					return false
				}
				// Otherwise, * can only be used inside acceptable aggregation functions.
				// Detect any uses of * outside such functions.
				sql.Inspect(e, func(e sql.Expression) bool {
					if err != nil {
						return false
					}
					switch e.(type) {
					case *expression.Star:
						// A bare star below the top level is unsupported.
						err = sql.ErrStarUnsupported.New()
						return false
					case *aggregation.Count, *aggregation.CountDistinct, *aggregation.JsonArray:
						// COUNT(*) / COUNT(DISTINCT *) / JSON_ARRAY(*) are fine;
						// skip the subtree so the star inside isn't flagged.
						if _, s := e.Children()[0].(*expression.Star); s {
							return false
						}
					}
					return true
				})
			}
		}
		return err == nil
	})
	if err != nil {
		return nil, transform.SameTree, err
	}
	return n, transform.SameTree, nil
}
   507  
func validateOperands(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
	// Validate that the number of columns in an operand or a top level
	// expression are as expected. The current rules are:
	// * Every top level expression of a node must have 1 column.
	// * The following expression nodes are allowed to have `n` columns as
	// long as `n` matches:
	//   * *plan.InSubquery, *expression.{Equals,NullSafeEquals,GreaterThan,LessThan,GreaterThanOrEqual,LessThanOrEqual}
	// * *expression.InTuple must have a tuple on the right side, the # of
	// columns for each element of the tuple must match the number of
	// columns of the expression on the left.
	// * Every other expression with operands must have NumColumns == 1.

	// We do not use plan.InspectExpressions here because we're treating
	// top-level expressions of sql.Node differently from subexpressions.
	var err error
	transform.Inspect(n, func(n sql.Node) bool {
		if n == nil {
			return false
		}

		// DDL nodes are exempt from operand-column validation.
		if plan.IsDDLNode(n) {
			return false
		}

		if er, ok := n.(sql.Expressioner); ok {
			for _, e := range er.Expressions() {
				// Top-level expressions must be single-column.
				nc := types.NumColumns(e.Type())
				if nc != 1 {
					if _, ok := er.(*plan.HashLookup); ok {
						// hash lookup expressions are tuples with >= 1 columns
						return true
					}
					err = sql.ErrInvalidOperandColumns.New(1, nc)
					return false
				}
				sql.Inspect(e, func(e sql.Expression) bool {
					if e == nil {
						return err == nil
					}
					if err != nil {
						return false
					}
					switch e.(type) {
					case *plan.InSubquery, *expression.Equals, *expression.NullSafeEquals, *expression.GreaterThan,
						*expression.LessThan, *expression.GreaterThanOrEqual, *expression.LessThanOrEqual:
						// Comparisons: both sides must have the same column count.
						err = types.ErrIfMismatchedColumns(e.Children()[0].Type(), e.Children()[1].Type())
					case *expression.InTuple, *expression.HashInTuple:
						t, ok := e.Children()[1].(expression.Tuple)
						if ok && len(t.Children()) == 1 {
							// A single element Tuple treats itself like the element it contains.
							err = types.ErrIfMismatchedColumns(e.Children()[0].Type(), e.Children()[1].Type())
						} else {
							// Each tuple element must match the left side's column count.
							err = types.ErrIfMismatchedColumnsInTuple(e.Children()[0].Type(), e.Children()[1].Type())
						}
					case *aggregation.Count, *aggregation.CountDistinct, *aggregation.JsonArray:
						// COUNT(*) and friends take a star child; skip that subtree.
						if _, s := e.Children()[0].(*expression.Star); s {
							return false
						}
						// Otherwise each argument must be single-column.
						for _, e := range e.Children() {
							nc := types.NumColumns(e.Type())
							if nc != 1 {
								err = sql.ErrInvalidOperandColumns.New(1, nc)
							}
						}
					case expression.Tuple:
						// Tuple expressions can contain tuples...
					case *plan.ExistsSubquery:
						// Any number of columns are allowed.
					default:
						// All other operands must be single-column.
						for _, e := range e.Children() {
							nc := types.NumColumns(e.Type())
							if nc != 1 {
								err = sql.ErrInvalidOperandColumns.New(1, nc)
							}
						}
					}
					return err == nil
				})
			}
		}
		return err == nil
	})
	if err != nil {
		return nil, transform.SameTree, err
	}
	return n, transform.SameTree, nil
}
   595  
// validateSubqueryColumns checks that every subquery's GetField expressions
// have field indexes within the range of the outer scope plus the subquery's
// own child schemas. NOTE: this rule is currently disabled (see below) and
// always returns the node unchanged.
func validateSubqueryColumns(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
	// Then validate that every subquery has field indexes within the correct range
	// TODO: Why is this only for subqueries?

	// TODO: Currently disabled.
	if true {
		return n, transform.SameTree, nil
	}

	var outOfRangeIndexExpression sql.Expression
	var outOfRangeColumns int
	transform.InspectExpressionsWithNode(n, func(n sql.Node, e sql.Expression) bool {
		s, ok := e.(*plan.Subquery)
		if !ok {
			return true
		}

		// Upper bound on field indexes visible from the outer scope.
		outerScopeRowLen := len(scope.Schema()) + len(Schemas(n.Children()))
		transform.Inspect(s.Query, func(n sql.Node) bool {
			if n == nil {
				return true
			}
			// TODO: the schema of the rows seen by children of
			// these nodes are not reflected in the schema
			// calculations here. This needs to be rationalized
			// across the analyzer.
			switch n := n.(type) {
			case *plan.JoinNode:
				// Lookup joins see rows whose schema isn't modeled here; skip them.
				return !n.Op.IsLookup()
			default:
			}
			if es, ok := n.(sql.Expressioner); ok {
				childSchemaLen := len(Schemas(n.Children()))
				for _, e := range es.Expressions() {
					sql.Inspect(e, func(e sql.Expression) bool {
						// Record the first GetField whose index exceeds the
						// combined outer + child schema length.
						if gf, ok := e.(*expression.GetField); ok {
							if gf.Index() >= outerScopeRowLen+childSchemaLen {
								outOfRangeIndexExpression = gf
								outOfRangeColumns = outerScopeRowLen + childSchemaLen
							}
						}
						return outOfRangeIndexExpression == nil
					})
				}
			}
			return outOfRangeIndexExpression == nil
		})
		return outOfRangeIndexExpression == nil
	})
	if outOfRangeIndexExpression != nil {
		return nil, transform.SameTree, analyzererrors.ErrSubqueryFieldIndex.New(outOfRangeIndexExpression, outOfRangeColumns)
	}

	return n, transform.SameTree, nil
}
   651  
   652  func stringContains(strs []string, target string) bool {
   653  	lowerTarget := strings.ToLower(target)
   654  	for _, s := range strs {
   655  		if lowerTarget == strings.ToLower(s) {
   656  			return true
   657  		}
   658  	}
   659  	return false
   660  }
   661  
   662  func tableColsContains(strs []tableCol, target tableCol) bool {
   663  	for _, s := range strs {
   664  		if s == target {
   665  			return true
   666  		}
   667  	}
   668  	return false
   669  }
   670  
   671  // validateReadOnlyDatabase invalidates queries that attempt to write to ReadOnlyDatabases.
   672  func validateReadOnlyDatabase(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   673  	valid := true
   674  	var readOnlyDB sql.ReadOnlyDatabase
   675  	enforceReadOnly := scope.EnforcesReadOnly()
   676  
   677  	// if a ReadOnlyDatabase is found, invalidate the query
   678  	readOnlyDBSearch := func(node sql.Node) bool {
   679  		if rt, ok := node.(*plan.ResolvedTable); ok {
   680  			if ro, ok := rt.SqlDatabase.(sql.ReadOnlyDatabase); ok {
   681  				if ro.IsReadOnly() {
   682  					readOnlyDB = ro
   683  					valid = false
   684  				} else if enforceReadOnly {
   685  					valid = false
   686  				}
   687  			}
   688  		}
   689  		return valid
   690  	}
   691  
   692  	transform.Inspect(n, func(node sql.Node) bool {
   693  		switch n := n.(type) {
   694  		case *plan.DeleteFrom, *plan.Update, *plan.LockTables, *plan.UnlockTables:
   695  			transform.Inspect(node, readOnlyDBSearch)
   696  			return false
   697  
   698  		case *plan.InsertInto:
   699  			// ReadOnlyDatabase can be an insertion Source,
   700  			// only inspect the Destination tree
   701  			transform.Inspect(n.Destination, readOnlyDBSearch)
   702  			return false
   703  
   704  		case *plan.CreateTable:
   705  			if ro, ok := n.Database().(sql.ReadOnlyDatabase); ok {
   706  				if ro.IsReadOnly() {
   707  					readOnlyDB = ro
   708  					valid = false
   709  				} else if enforceReadOnly {
   710  					valid = false
   711  				}
   712  			}
   713  			// "CREATE TABLE ... LIKE ..." and
   714  			// "CREATE TABLE ... AS ..."
   715  			// can both use ReadOnlyDatabases as a source,
   716  			// so don't descend here.
   717  			return false
   718  
   719  		default:
   720  			// CreateTable is the only DDL node allowed
   721  			// to contain a ReadOnlyDatabase
   722  			if plan.IsDDLNode(n) {
   723  				transform.Inspect(n, readOnlyDBSearch)
   724  				return false
   725  			}
   726  		}
   727  
   728  		return valid
   729  	})
   730  	if !valid {
   731  		if enforceReadOnly {
   732  			return nil, transform.SameTree, sql.ErrProcedureCallAsOfReadOnly.New()
   733  		} else {
   734  			return nil, transform.SameTree, analyzererrors.ErrReadOnlyDatabase.New(readOnlyDB.Name())
   735  		}
   736  	}
   737  
   738  	return n, transform.SameTree, nil
   739  }
   740  
   741  // validateReadOnlyTransaction invalidates read only transactions that try to perform improper write operations.
   742  func validateReadOnlyTransaction(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   743  	t := ctx.GetTransaction()
   744  
   745  	if t == nil {
   746  		return n, transform.SameTree, nil
   747  	}
   748  
   749  	// If this is a normal read write transaction don't enforce read-only. Otherwise we must prevent an invalid query.
   750  	if !t.IsReadOnly() && !scope.EnforcesReadOnly() {
   751  		return n, transform.SameTree, nil
   752  	}
   753  
   754  	valid := true
   755  
   756  	isTempTable := func(table sql.Table) bool {
   757  		tt, isTempTable := table.(sql.TemporaryTable)
   758  		if !isTempTable {
   759  			valid = false
   760  		}
   761  
   762  		return tt.IsTemporary()
   763  	}
   764  
   765  	temporaryTableSearch := func(node sql.Node) bool {
   766  		if rt, ok := node.(*plan.ResolvedTable); ok {
   767  			valid = isTempTable(rt.Table)
   768  		}
   769  		return valid
   770  	}
   771  
   772  	transform.Inspect(n, func(node sql.Node) bool {
   773  		switch n := n.(type) {
   774  		case *plan.DeleteFrom, *plan.Update, *plan.UnlockTables:
   775  			transform.Inspect(node, temporaryTableSearch)
   776  			return false
   777  		case *plan.InsertInto:
   778  			transform.Inspect(n.Destination, temporaryTableSearch)
   779  			return false
   780  		case *plan.LockTables:
   781  			// TODO: Technically we should allow for the locking of temporary tables but the LockTables implementation
   782  			// needs substantial refactoring.
   783  			valid = false
   784  			return false
   785  		case *plan.CreateTable:
   786  			// MySQL explicitly blocks the creation of temporary tables in a read only transaction.
   787  			if n.Temporary() == plan.IsTempTable {
   788  				valid = false
   789  			}
   790  
   791  			return false
   792  		default:
   793  			// DDL statements have an implicit commits which makes them valid to be executed in READ ONLY transactions.
   794  			if plan.IsDDLNode(n) {
   795  				valid = true
   796  				return false
   797  			}
   798  
   799  			return valid
   800  		}
   801  	})
   802  
   803  	if !valid {
   804  		return nil, transform.SameTree, sql.ErrReadOnlyTransaction.New()
   805  	}
   806  
   807  	return n, transform.SameTree, nil
   808  }
   809  
   810  // validateAggregations returns an error if an Aggregation expression has been used in
   811  // an invalid way, such as appearing outside of a GroupBy or Window node, or if an aggregate
   812  // function is used with the implicit all-rows grouping and contains projected expressions with
   813  // window aggregation functions that reference non-aggregated columns. Only GroupBy and Window
   814  // nodes know how to evaluate Aggregation expressions.
   815  //
   816  // See https://github.com/dolthub/go-mysql-server/issues/542 for some queries
   817  // that should be supported but that currently trigger this validation because
   818  // aggregation expressions end up in the wrong place.
   819  func validateAggregations(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   820  	var validationErr error
   821  	transform.Inspect(n, func(n sql.Node) bool {
   822  		switch n := n.(type) {
   823  		case *plan.GroupBy:
   824  			validationErr = checkForAggregationFunctions(n.GroupByExprs)
   825  		case *plan.Window:
   826  			validationErr = checkForNonAggregatedColumnReferences(n)
   827  		case sql.Expressioner:
   828  			validationErr = checkForAggregationFunctions(n.Expressions())
   829  		default:
   830  		}
   831  		return validationErr == nil
   832  	})
   833  
   834  	return n, transform.SameTree, validationErr
   835  }
   836  
   837  // checkForAggregationFunctions returns an ErrAggregationUnsupported error if any aggregation
   838  // functions are found in the specified expressions.
   839  func checkForAggregationFunctions(exprs []sql.Expression) error {
   840  	var validationErr error
   841  	for _, e := range exprs {
   842  		sql.Inspect(e, func(ie sql.Expression) bool {
   843  			if _, ok := ie.(sql.Aggregation); ok {
   844  				validationErr = sql.ErrAggregationUnsupported.New(e.String())
   845  			}
   846  			return validationErr == nil
   847  		})
   848  	}
   849  	return validationErr
   850  }
   851  
   852  // checkForNonAggregatedColumnReferences returns an ErrNonAggregatedColumnWithoutGroupBy error
   853  // if an aggregate function with the implicit/all-rows grouping is mixed with aggregate window
   854  // functions that reference a non-aggregated column.
   855  // You cannot mix aggregations on the implicit/all-rows grouping with window aggregations.
   856  func checkForNonAggregatedColumnReferences(w *plan.Window) error {
   857  	for _, expr := range w.ProjectedExprs() {
   858  		if agg, ok := expr.(sql.Aggregation); ok {
   859  			if agg.Window() == nil {
   860  				index, gf := findFirstWindowAggregationColumnReference(w)
   861  
   862  				if index >= 0 {
   863  					return sql.ErrNonAggregatedColumnWithoutGroupBy.New(index, gf.String())
   864  				} else {
   865  					// We should always have an index and GetField value to use, but just in case
   866  					// something changes that, return a similar error message without those details.
   867  					return fmt.Errorf("in aggregated query without GROUP BY, expression in " +
   868  						"SELECT list contains nonaggregated column; " +
   869  						"this is incompatible with sql_mode=only_full_group_by")
   870  				}
   871  			}
   872  		}
   873  	}
   874  	return nil
   875  }
   876  
   877  // findFirstWindowAggregationColumnReference returns the index and GetField expression for the
   878  // first column reference in the first window aggregation function in the specified node's
   879  // projection expressions. If no window aggregation function with a column reference is found,
   880  // (-1, nil) is returned. This information is needed to populate an
   881  // ErrNonAggregatedColumnWithoutGroupBy error.
   882  func findFirstWindowAggregationColumnReference(w *plan.Window) (index int, gf *expression.GetField) {
   883  	for index, expr := range w.ProjectedExprs() {
   884  		var firstColumnRef *expression.GetField
   885  
   886  		transform.InspectExpr(expr, func(e sql.Expression) bool {
   887  			if windowAgg, ok := e.(sql.WindowAggregation); ok {
   888  				transform.InspectExpr(windowAgg, func(e sql.Expression) bool {
   889  					if gf, ok := e.(*expression.GetField); ok {
   890  						firstColumnRef = gf
   891  						return true
   892  					}
   893  					return false
   894  				})
   895  				return firstColumnRef != nil
   896  			}
   897  			return false
   898  		})
   899  
   900  		if firstColumnRef != nil {
   901  			return index, firstColumnRef
   902  		}
   903  	}
   904  
   905  	return -1, nil
   906  }
   907  
   908  func validateExprSem(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   909  	var err error
   910  	transform.InspectExpressions(n, func(e sql.Expression) bool {
   911  		err = validateSem(e)
   912  		return err == nil
   913  	})
   914  	return n, transform.SameTree, err
   915  }
   916  
   917  // validateSem is a way to add validation logic for
   918  // specific expression types.
   919  // todo(max): Refactor and consolidate validation so it can
   920  // run before the rest of analysis. Add more expression types.
   921  // Add node equivalent.
   922  func validateSem(e sql.Expression) error {
   923  	switch e := e.(type) {
   924  	case *expression.And:
   925  		if err := logicalSem(e.BinaryExpressionStub); err != nil {
   926  			return err
   927  		}
   928  	case *expression.Or:
   929  		if err := logicalSem(e.BinaryExpressionStub); err != nil {
   930  			return err
   931  		}
   932  	default:
   933  	}
   934  	return nil
   935  }
   936  
   937  func logicalSem(e expression.BinaryExpressionStub) error {
   938  	if lc := fds(e.LeftChild); lc != 1 {
   939  		return sql.ErrInvalidOperandColumns.New(1, lc)
   940  	}
   941  	if rc := fds(e.RightChild); rc != 1 {
   942  		return sql.ErrInvalidOperandColumns.New(1, rc)
   943  	}
   944  	return nil
   945  }
   946  
   947  // fds counts the functional dependencies of an expression.
   948  // todo(max): input/output fd's should be part of the expression
   949  // interface.
   950  func fds(e sql.Expression) int {
   951  	switch e.(type) {
   952  	case *expression.UnresolvedColumn:
   953  		return 1
   954  	case *expression.UnresolvedFunction:
   955  		return 1
   956  	default:
   957  		return types.NumColumns(e.Type())
   958  	}
   959  }