github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/validate_create_table.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"strings"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  	"github.com/dolthub/go-mysql-server/sql/expression"
    22  	"github.com/dolthub/go-mysql-server/sql/plan"
    23  	"github.com/dolthub/go-mysql-server/sql/transform"
    24  	"github.com/dolthub/go-mysql-server/sql/types"
    25  )
    26  
// MaxBytePrefix is the largest index prefix length, in bytes, accepted by the prefix checks in this file
// (validatePrefixLength and validateModifyColumn compare computed byte lengths against it).
const MaxBytePrefix = 3072
    28  
    29  // validateCreateTable validates various constraints about CREATE TABLE statements. Some validation is currently done
    30  // at execution time, and should be moved here over time.
    31  func validateCreateTable(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    32  	ct, ok := n.(*plan.CreateTable)
    33  	if !ok {
    34  		return n, transform.SameTree, nil
    35  	}
    36  
    37  	err := validateIdentifiers(ct.Name(), ct.TableSpec())
    38  	if err != nil {
    39  		return nil, transform.SameTree, err
    40  	}
    41  
    42  	err = validateIndexes(ctx, ct.TableSpec())
    43  	if err != nil {
    44  		return nil, transform.SameTree, err
    45  	}
    46  
    47  	err = validateNoVirtualColumnsInPrimaryKey(ct.TableSpec())
    48  	if err != nil {
    49  		return nil, transform.SameTree, err
    50  	}
    51  
    52  	// passed validateIndexes, so they all must be valid indexes
    53  	// extract map of columns that have indexes defined over them
    54  	keyedColumns := make(map[string]bool)
    55  	for _, index := range ct.TableSpec().IdxDefs {
    56  		for _, col := range index.Columns {
    57  			keyedColumns[col.Name] = true
    58  		}
    59  	}
    60  
    61  	err = validateAutoIncrementModify(ct.CreateSchema.Schema, keyedColumns)
    62  	if err != nil {
    63  		return nil, transform.SameTree, err
    64  	}
    65  
    66  	return n, transform.SameTree, nil
    67  }
    68  
    69  func validateNoVirtualColumnsInPrimaryKey(spec *plan.TableSpec) error {
    70  	for _, c := range spec.Schema.Schema {
    71  		if c.PrimaryKey && c.Virtual {
    72  			return sql.ErrVirtualColumnPrimaryKey.New()
    73  		}
    74  	}
    75  	return nil
    76  }
    77  
    78  // validateAlterTable is a set of validation functions for ALTER TABLE statements not handled by more specific
    79  // validation rules
    80  func validateAlterTable(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    81  	var err error
    82  	// Inspect is required here because alter table statements with multiple clauses are represented as a block of
    83  	// plan nodes
    84  	transform.Inspect(n, func(sql.Node) bool {
    85  		switch n := n.(type) {
    86  		case *plan.RenameTable:
    87  			for _, name := range n.NewNames {
    88  				err = validateIdentifier(name)
    89  				if err != nil {
    90  					return false
    91  				}
    92  			}
    93  		case *plan.CreateCheck:
    94  			err = validateIdentifier(n.Check.Name)
    95  			if err != nil {
    96  				return false
    97  			}
    98  		case *plan.CreateForeignKey:
    99  			err = validateIdentifier(n.FkDef.Name)
   100  			if err != nil {
   101  				return false
   102  			}
   103  		}
   104  
   105  		return true
   106  	})
   107  
   108  	if err != nil {
   109  		return nil, transform.SameTree, err
   110  	}
   111  
   112  	return n, transform.SameTree, nil
   113  }
   114  
   115  // validateIdentifiers validates various constraints about identifiers in CREATE TABLE / ALTER TABLE
   116  // statements.
   117  func validateIdentifiers(name string, spec *plan.TableSpec) error {
   118  	if len(name) > sql.MaxIdentifierLength {
   119  		return sql.ErrInvalidIdentifier.New(name)
   120  	}
   121  
   122  	colNames := make(map[string]bool)
   123  	for _, col := range spec.Schema.Schema {
   124  		if len(col.Name) > sql.MaxIdentifierLength {
   125  			return sql.ErrInvalidIdentifier.New(col.Name)
   126  		}
   127  		lower := strings.ToLower(col.Name)
   128  		if colNames[lower] {
   129  			return sql.ErrDuplicateColumn.New(col.Name)
   130  		}
   131  		colNames[lower] = true
   132  	}
   133  
   134  	for _, chDef := range spec.ChDefs {
   135  		if len(chDef.Name) > sql.MaxIdentifierLength {
   136  			return sql.ErrInvalidIdentifier.New(chDef.Name)
   137  		}
   138  	}
   139  
   140  	for _, idxDef := range spec.IdxDefs {
   141  		if len(idxDef.IndexName) > sql.MaxIdentifierLength {
   142  			return sql.ErrInvalidIdentifier.New(idxDef.IndexName)
   143  		}
   144  	}
   145  
   146  	for _, fkDef := range spec.FkDefs {
   147  		if len(fkDef.Name) > sql.MaxIdentifierLength {
   148  			return sql.ErrInvalidIdentifier.New(fkDef.Name)
   149  		}
   150  	}
   151  
   152  	return nil
   153  }
   154  
   155  func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   156  	var sch sql.Schema
   157  	var indexes []string
   158  	var validator sql.SchemaValidator
   159  	keyedColumns := make(map[string]bool)
   160  	var err error
   161  	transform.Inspect(n, func(n sql.Node) bool {
   162  		if st, ok := n.(sql.SchemaTarget); ok {
   163  			sch = st.TargetSchema()
   164  		}
   165  		switch n := n.(type) {
   166  		case *plan.ModifyColumn:
   167  			if rt, ok := n.Table.(*plan.ResolvedTable); ok {
   168  				if sv, ok := rt.UnwrappedDatabase().(sql.SchemaValidator); ok {
   169  					validator = sv
   170  				}
   171  			}
   172  			keyedColumns, err = getTableIndexColumns(ctx, n.Table)
   173  			return false
   174  		case *plan.RenameColumn:
   175  			if rt, ok := n.Table.(*plan.ResolvedTable); ok {
   176  				if sv, ok := rt.UnwrappedDatabase().(sql.SchemaValidator); ok {
   177  					validator = sv
   178  				}
   179  			}
   180  			return false
   181  		case *plan.AddColumn:
   182  			if rt, ok := n.Table.(*plan.ResolvedTable); ok {
   183  				if sv, ok := rt.UnwrappedDatabase().(sql.SchemaValidator); ok {
   184  					validator = sv
   185  				}
   186  			}
   187  			keyedColumns, err = getTableIndexColumns(ctx, n.Table)
   188  			return false
   189  		case *plan.DropColumn:
   190  			if rt, ok := n.Table.(*plan.ResolvedTable); ok {
   191  				if sv, ok := rt.UnwrappedDatabase().(sql.SchemaValidator); ok {
   192  					validator = sv
   193  				}
   194  			}
   195  			return false
   196  		case *plan.AlterIndex:
   197  			if rt, ok := n.Table.(*plan.ResolvedTable); ok {
   198  				if sv, ok := rt.UnwrappedDatabase().(sql.SchemaValidator); ok {
   199  					validator = sv
   200  				}
   201  			}
   202  			indexes, err = getTableIndexNames(ctx, a, n.Table)
   203  		default:
   204  		}
   205  		return true
   206  	})
   207  
   208  	if err != nil {
   209  		return nil, transform.SameTree, err
   210  	}
   211  
   212  	// Skip this validation if we didn't find one or more of the above node types
   213  	if len(sch) == 0 {
   214  		return n, transform.SameTree, nil
   215  	}
   216  
   217  	sch = sch.Copy() // Make a copy of the original schema to deal with any references to the original table.
   218  	initialSch := sch
   219  
   220  	addedColumn := false
   221  
   222  	// Need a TransformUp here because multiple of these statement types can be nested under a Block node.
   223  	// It doesn't look it, but this is actually an iterative loop over all the independent clauses in an ALTER statement
   224  	n, same, err := transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
   225  		switch nn := n.(type) {
   226  		case *plan.ModifyColumn:
   227  			n, err := nn.WithTargetSchema(sch.Copy())
   228  			if err != nil {
   229  				return nil, transform.SameTree, err
   230  			}
   231  
   232  			sch, err = validateModifyColumn(ctx, initialSch, sch, n.(*plan.ModifyColumn), keyedColumns)
   233  			if err != nil {
   234  				return nil, transform.SameTree, err
   235  			}
   236  			return n, transform.NewTree, nil
   237  		case *plan.RenameColumn:
   238  			n, err := nn.WithTargetSchema(sch.Copy())
   239  			if err != nil {
   240  				return nil, transform.SameTree, err
   241  			}
   242  			sch, err = validateRenameColumn(initialSch, sch, n.(*plan.RenameColumn))
   243  			if err != nil {
   244  				return nil, transform.SameTree, err
   245  			}
   246  			return n, transform.NewTree, nil
   247  		case *plan.AddColumn:
   248  			n, err := nn.WithTargetSchema(sch.Copy())
   249  			if err != nil {
   250  				return nil, transform.SameTree, err
   251  			}
   252  
   253  			sch, err = validateAddColumn(initialSch, sch, n.(*plan.AddColumn))
   254  			if err != nil {
   255  				return nil, transform.SameTree, err
   256  			}
   257  
   258  			addedColumn = true
   259  			return n, transform.NewTree, nil
   260  		case *plan.DropColumn:
   261  			n, err := nn.WithTargetSchema(sch.Copy())
   262  			if err != nil {
   263  				return nil, transform.SameTree, err
   264  			}
   265  			sch, err = validateDropColumn(initialSch, sch, n.(*plan.DropColumn))
   266  			if err != nil {
   267  				return nil, transform.SameTree, err
   268  			}
   269  			delete(keyedColumns, nn.Column)
   270  
   271  			return n, transform.NewTree, nil
   272  		case *plan.AlterIndex:
   273  			n, err := nn.WithTargetSchema(sch.Copy())
   274  			if err != nil {
   275  				return nil, transform.SameTree, err
   276  			}
   277  			indexes, err = validateAlterIndex(ctx, initialSch, sch, n.(*plan.AlterIndex), indexes)
   278  			if err != nil {
   279  				return nil, transform.SameTree, err
   280  			}
   281  
   282  			keyedColumns = updateKeyedColumns(keyedColumns, nn)
   283  			return n, transform.NewTree, nil
   284  		case *plan.AlterPK:
   285  			n, err := nn.WithTargetSchema(sch.Copy())
   286  			if err != nil {
   287  				return nil, transform.SameTree, err
   288  			}
   289  			sch, err = validatePrimaryKey(ctx, initialSch, sch, n.(*plan.AlterPK))
   290  			if err != nil {
   291  				return nil, transform.SameTree, err
   292  			}
   293  			return n, transform.NewTree, nil
   294  		case *plan.AlterDefaultSet:
   295  			n, err := nn.WithTargetSchema(sch.Copy())
   296  			if err != nil {
   297  				return nil, transform.SameTree, err
   298  			}
   299  			sch, err = validateAlterDefault(initialSch, sch, n.(*plan.AlterDefaultSet))
   300  			if err != nil {
   301  				return nil, transform.SameTree, err
   302  			}
   303  			return n, transform.NewTree, nil
   304  		case *plan.AlterDefaultDrop:
   305  			n, err := nn.WithTargetSchema(sch.Copy())
   306  			if err != nil {
   307  				return nil, transform.SameTree, err
   308  			}
   309  			sch, err = validateDropDefault(initialSch, sch, n.(*plan.AlterDefaultDrop))
   310  			if err != nil {
   311  				return nil, transform.SameTree, err
   312  			}
   313  			return n, transform.NewTree, nil
   314  		}
   315  		return n, transform.SameTree, nil
   316  	})
   317  
   318  	if err != nil {
   319  		return nil, transform.SameTree, err
   320  	}
   321  
   322  	if validator != nil {
   323  		if err := validator.ValidateSchema(sch); err != nil {
   324  			return nil, transform.SameTree, err
   325  		}
   326  	}
   327  
   328  	// We can't evaluate auto-increment until the end of the analysis, since we break adding a new auto-increment unique
   329  	// column into two steps: first add the column, then create the index. If there was no index created, that's an error.
   330  	if addedColumn {
   331  		err = validateAutoIncrementAdd(ctx, sch, keyedColumns)
   332  		if err != nil {
   333  			return nil, false, err
   334  		}
   335  	}
   336  
   337  	return n, same, nil
   338  }
   339  
   340  // updateKeyedColumns updates the keyedColumns map based on the action of the AlterIndex node
   341  func updateKeyedColumns(keyedColumns map[string]bool, n *plan.AlterIndex) map[string]bool {
   342  	switch n.Action {
   343  	case plan.IndexAction_Create:
   344  		for _, col := range n.Columns {
   345  			keyedColumns[col.Name] = true
   346  		}
   347  	case plan.IndexAction_Drop:
   348  		for _, col := range n.Columns {
   349  			delete(keyedColumns, col.Name)
   350  		}
   351  	}
   352  
   353  	return keyedColumns
   354  }
   355  
   356  // validateRenameColumn checks that a DDL RenameColumn node can be safely executed (e.g. no collision with other
   357  // column names, doesn't invalidate any table check constraints).
   358  //
   359  // Note that schema is passed in twice, because one version is the initial version before the alter column expressions
   360  // are applied, and the second version is the current schema that is being modified as multiple nodes are processed.
   361  func validateRenameColumn(initialSch, sch sql.Schema, rc *plan.RenameColumn) (sql.Schema, error) {
   362  	table := rc.Table
   363  	nameable := table.(sql.Nameable)
   364  
   365  	err := validateIdentifier(rc.NewColumnName)
   366  	if err != nil {
   367  		return nil, err
   368  	}
   369  
   370  	// Check for column name collisions
   371  	if sch.Contains(rc.NewColumnName, nameable.Name()) {
   372  		return nil, sql.ErrColumnExists.New(rc.NewColumnName)
   373  	}
   374  
   375  	// Make sure this column exists. MySQL only checks the original schema, which means you can't add a column and
   376  	// rename it in the same statement. But, it also has to exist in the modified schema -- it can't have been renamed or
   377  	// dropped in this statement.
   378  	if !initialSch.Contains(rc.ColumnName, nameable.Name()) || !sch.Contains(rc.ColumnName, nameable.Name()) {
   379  		return nil, sql.ErrTableColumnNotFound.New(nameable.Name(), rc.ColumnName)
   380  	}
   381  
   382  	err = validateColumnNotUsedInCheckConstraint(rc.ColumnName, rc.Checks())
   383  	if err != nil {
   384  		return nil, err
   385  	}
   386  
   387  	return renameInSchema(sch, rc.ColumnName, rc.NewColumnName, nameable.Name()), nil
   388  }
   389  
   390  func validateAddColumn(initialSch sql.Schema, schema sql.Schema, ac *plan.AddColumn) (sql.Schema, error) {
   391  	table := ac.Table
   392  	nameable := table.(sql.Nameable)
   393  
   394  	err := validateIdentifier(ac.Column().Name)
   395  	if err != nil {
   396  		return nil, err
   397  	}
   398  
   399  	// Name collisions
   400  	if schema.Contains(ac.Column().Name, nameable.Name()) {
   401  		return nil, sql.ErrColumnExists.New(ac.Column().Name)
   402  	}
   403  
   404  	// Make sure columns named in After clause exist
   405  	idx := -1
   406  	if ac.Order() != nil && ac.Order().AfterColumn != "" {
   407  		afterColumn := ac.Order().AfterColumn
   408  		idx = schema.IndexOf(afterColumn, nameable.Name())
   409  		if idx < 0 {
   410  			return nil, sql.ErrTableColumnNotFound.New(nameable.Name(), afterColumn)
   411  		}
   412  	}
   413  
   414  	newSch := make(sql.Schema, 0, len(schema)+1)
   415  	if idx >= 0 {
   416  		newSch = append(newSch, schema[:idx+1]...)
   417  		newSch = append(newSch, ac.Column().Copy())
   418  		newSch = append(newSch, schema[idx+1:]...)
   419  	} else { // new column at end
   420  		newSch = append(newSch, schema...)
   421  		newSch = append(newSch, ac.Column().Copy())
   422  	}
   423  
   424  	return newSch, nil
   425  }
   426  
   427  // isStrictMysqlCompatibilityEnabled returns true if the strict_mysql_compatibility SQL system variable has been
   428  // turned on in this session, otherwise it returns false, or any unexpected error querying the system variable.
   429  func isStrictMysqlCompatibilityEnabled(ctx *sql.Context) (bool, error) {
   430  	strictMysqlCompatibility, err := ctx.GetSessionVariable(ctx, "strict_mysql_compatibility")
   431  	if err != nil {
   432  		return false, err
   433  	}
   434  	i, ok := strictMysqlCompatibility.(int8)
   435  	if !ok {
   436  		return false, nil
   437  	}
   438  	return i == 1, nil
   439  }
   440  
   441  func validateModifyColumn(ctx *sql.Context, initialSch sql.Schema, schema sql.Schema, mc *plan.ModifyColumn, keyedColumns map[string]bool) (sql.Schema, error) {
   442  	table := mc.Table
   443  	nameable := table.(sql.Nameable)
   444  
   445  	err := validateIdentifier(mc.NewColumn().Name)
   446  	if err != nil {
   447  		return nil, err
   448  	}
   449  
   450  	// Look for the old column and throw an error if it's not there. The column cannot have been renamed in the same
   451  	// statement. This matches the MySQL behavior.
   452  	if !schema.Contains(mc.Column(), nameable.Name()) ||
   453  		!initialSch.Contains(mc.Column(), nameable.Name()) {
   454  		return nil, sql.ErrTableColumnNotFound.New(nameable.Name(), mc.Column())
   455  	}
   456  
   457  	newSch := replaceInSchema(schema, mc.NewColumn(), nameable.Name())
   458  
   459  	err = validateAutoIncrementModify(newSch, keyedColumns)
   460  	if err != nil {
   461  		return nil, err
   462  	}
   463  
   464  	// TODO: When a column is being modified, we should ideally check that any existing table check constraints
   465  	//       are still valid (e.g. if the column type changed) and throw an error if they are invalidated.
   466  	//       That would be consistent with MySQL behavior.
   467  
   468  	// not becoming a text/blob column
   469  	newCol := mc.NewColumn()
   470  	if !types.IsTextBlob(newCol.Type) {
   471  		return newSch, nil
   472  	}
   473  
   474  	strictMysqlCompatibility, err := isStrictMysqlCompatibilityEnabled(ctx)
   475  	if err != nil {
   476  		return nil, err
   477  	}
   478  
   479  	// any indexes that use this column must have a prefix length
   480  	ia, err := newIndexAnalyzerForNode(ctx, table)
   481  	if err != nil {
   482  		return nil, err
   483  	}
   484  	indexes := ia.IndexesByTable(ctx, ctx.GetCurrentDatabase(), getTableName(table))
   485  	for _, index := range indexes {
   486  		if index.IsFullText() {
   487  			continue
   488  		}
   489  		prefixLengths := index.PrefixLengths()
   490  		for i, expr := range index.Expressions() {
   491  			col := plan.GetColumnFromIndexExpr(expr, getTable(table))
   492  			if col.Name == mc.Column() {
   493  				if len(prefixLengths) == 0 || prefixLengths[i] == 0 {
   494  					// MariaDB allows BLOB and TEXT columns to be used in unique keys WITHOUT specifying
   495  					// a prefix length, but MySQL does not, so still throw an error if we are in strict
   496  					// MySQL compatibility mode.
   497  					if !index.IsFullText() && (!index.IsUnique() && !strictMysqlCompatibility) {
   498  						return nil, sql.ErrInvalidBlobTextKey.New(col.Name)
   499  					}
   500  				}
   501  				if types.IsTextOnly(newCol.Type) && len(prefixLengths) > 0 && prefixLengths[i]*4 > MaxBytePrefix {
   502  					return nil, sql.ErrKeyTooLong.New()
   503  				}
   504  			}
   505  		}
   506  	}
   507  
   508  	return newSch, nil
   509  }
   510  
   511  func validateIdentifier(name string) error {
   512  	if len(name) > sql.MaxIdentifierLength {
   513  		return sql.ErrInvalidIdentifier.New(name)
   514  	}
   515  	return nil
   516  }
   517  
   518  func validateDropColumn(initialSch, sch sql.Schema, dc *plan.DropColumn) (sql.Schema, error) {
   519  	table := dc.Table
   520  	nameable := table.(sql.Nameable)
   521  
   522  	// Look for the column to be dropped and throw an error if it's not there. It must exist in the original schema before
   523  	// this statement was run, it cannot have been added as part of this ALTER TABLE statement. This matches the MySQL
   524  	// behavior.
   525  	if !initialSch.Contains(dc.Column, nameable.Name()) || !sch.Contains(dc.Column, nameable.Name()) {
   526  		return nil, sql.ErrTableColumnNotFound.New(nameable.Name(), dc.Column)
   527  	}
   528  
   529  	err := validateColumnSafeToDropWithCheckConstraint(dc.Column, dc.Checks())
   530  	if err != nil {
   531  		return nil, err
   532  	}
   533  
   534  	newSch := removeInSchema(sch, dc.Column, nameable.Name())
   535  
   536  	return newSch, nil
   537  }
   538  
   539  // validateColumnNotUsedInCheckConstraint validates that the specified column name is not referenced in any of
   540  // the specified table check constraints.
   541  func validateColumnNotUsedInCheckConstraint(columnName string, checks sql.CheckConstraints) error {
   542  	var err error
   543  	for _, check := range checks {
   544  		_ = transform.InspectExpr(check.Expr, func(e sql.Expression) bool {
   545  			var name string
   546  			switch e := e.(type) {
   547  			case *expression.UnresolvedColumn:
   548  				name = e.Name()
   549  			case *expression.GetField:
   550  				name = e.Name()
   551  			default:
   552  				return false
   553  			}
   554  			if strings.EqualFold(name, columnName) {
   555  				err = sql.ErrCheckConstraintInvalidatedByColumnAlter.New(columnName, check.Name)
   556  				return true
   557  			}
   558  			return false
   559  		})
   560  
   561  		if err != nil {
   562  			return err
   563  		}
   564  	}
   565  	return nil
   566  }
   567  
   568  // validateColumnSafeToDropWithCheckConstraint validates that the specified column name is safe to drop, even if
   569  // referenced in a check constraint. Columns referenced in check constraints can be dropped if they are the only
   570  // column referenced in the check constraint.
   571  func validateColumnSafeToDropWithCheckConstraint(columnName string, checks sql.CheckConstraints) error {
   572  	var err error
   573  	for _, check := range checks {
   574  		hasOtherCol := false
   575  		hasMatchingCol := false
   576  		_ = transform.InspectExpr(check.Expr, func(e sql.Expression) bool {
   577  			var colName string
   578  			switch e := e.(type) {
   579  			case *expression.UnresolvedColumn:
   580  				colName = e.Name()
   581  			case *expression.GetField:
   582  				colName = e.Name()
   583  			default:
   584  				return false
   585  			}
   586  			if strings.EqualFold(columnName, colName) {
   587  				if hasOtherCol {
   588  					err = sql.ErrCheckConstraintInvalidatedByColumnAlter.New(columnName, check.Name)
   589  					return true
   590  				} else {
   591  					hasMatchingCol = true
   592  				}
   593  			} else {
   594  				hasOtherCol = true
   595  			}
   596  			return false
   597  		})
   598  
   599  		if hasOtherCol && hasMatchingCol {
   600  			err = sql.ErrCheckConstraintInvalidatedByColumnAlter.New(columnName, check.Name)
   601  		}
   602  
   603  		if err != nil {
   604  			return err
   605  		}
   606  	}
   607  	return nil
   608  }
   609  
   610  // validateAlterIndex validates the specified column can have an index added, dropped, or renamed. Returns an updated
   611  // list of index name given the add, drop, or rename operations.
   612  func validateAlterIndex(ctx *sql.Context, initialSch, sch sql.Schema, ai *plan.AlterIndex, indexes []string) ([]string, error) {
   613  	tableName := getTableName(ai.Table)
   614  
   615  	switch ai.Action {
   616  	case plan.IndexAction_Create:
   617  		err := validateIdentifier(ai.IndexName)
   618  		if err != nil {
   619  			return nil, err
   620  		}
   621  
   622  		badColName, ok := missingIdxColumn(ai.Columns, sch, tableName)
   623  		if !ok {
   624  			return nil, sql.ErrKeyColumnDoesNotExist.New(badColName)
   625  		}
   626  		err = validateIndexType(ctx, ai.Columns, sch, ai.Constraint)
   627  		if err != nil {
   628  			return nil, err
   629  		}
   630  
   631  		if ai.Constraint == sql.IndexConstraint_Spatial {
   632  			if len(ai.Columns) != 1 {
   633  				return nil, sql.ErrTooManyKeyParts.New(1)
   634  			}
   635  			idx := sch.IndexOfColName(ai.Columns[0].Name)
   636  			if idx == -1 {
   637  				return nil, sql.ErrColumnNotFound.New(ai.Columns[0].Name)
   638  			}
   639  			schCol := sch[idx]
   640  			spatialCol, ok := schCol.Type.(sql.SpatialColumnType)
   641  			if !ok {
   642  				return nil, sql.ErrBadSpatialIdxCol.New()
   643  			}
   644  			if schCol.Nullable {
   645  				return nil, sql.ErrNullableSpatialIdx.New()
   646  			}
   647  			if _, ok = spatialCol.GetSpatialTypeSRID(); !ok {
   648  				ctx.Warn(3674, "The spatial index on column '%s' will not be used by the query optimizer since the column does not have an SRID attribute. Consider adding an SRID attribyte to the column.", schCol.Name)
   649  			}
   650  		}
   651  
   652  		return append(indexes, ai.IndexName), nil
   653  	case plan.IndexAction_Drop:
   654  		savedIdx := -1
   655  		for i, idx := range indexes {
   656  			if strings.EqualFold(idx, ai.IndexName) {
   657  				savedIdx = i
   658  				break
   659  			}
   660  		}
   661  
   662  		if savedIdx == -1 {
   663  			return nil, sql.ErrCantDropFieldOrKey.New(ai.IndexName)
   664  		}
   665  
   666  		// Remove the index from the list
   667  		return append(indexes[:savedIdx], indexes[savedIdx+1:]...), nil
   668  	case plan.IndexAction_Rename:
   669  		err := validateIdentifier(ai.IndexName)
   670  		if err != nil {
   671  			return nil, err
   672  		}
   673  
   674  		savedIdx := -1
   675  		for i, idx := range indexes {
   676  			if strings.EqualFold(idx, ai.PreviousIndexName) {
   677  				savedIdx = i
   678  			}
   679  		}
   680  
   681  		if savedIdx == -1 {
   682  			return nil, sql.ErrCantDropFieldOrKey.New(ai.IndexName)
   683  		}
   684  
   685  		// Simulate the rename by deleting the old name and adding the new one.
   686  		return append(append(indexes[:savedIdx], indexes[savedIdx+1:]...), ai.IndexName), nil
   687  	}
   688  
   689  	return indexes, nil
   690  }
   691  
   692  // validatePrefixLength handles all errors related to creating indexes with prefix lengths
   693  func validatePrefixLength(ctx *sql.Context, schCol *sql.Column, idxCol sql.IndexColumn, constraint sql.IndexConstraint) error {
   694  	isFullText := constraint == sql.IndexConstraint_Fulltext
   695  	isUnique := constraint == sql.IndexConstraint_Unique
   696  
   697  	// Prefix length is ignored for full text indexes
   698  	if isFullText {
   699  		return nil
   700  	}
   701  
   702  	// Throw prefix length error for non-string types with prefixes
   703  	if idxCol.Length > 0 && !types.IsText(schCol.Type) {
   704  		return sql.ErrInvalidIndexPrefix.New(schCol.Name)
   705  	}
   706  
   707  	// Get prefix key length in bytes, so times 4 for varchar, text, and varchar
   708  	prefixByteLength := idxCol.Length
   709  	if types.IsTextOnly(schCol.Type) {
   710  		prefixByteLength = 4 * idxCol.Length
   711  	}
   712  
   713  	// Prefix length is longer than max
   714  	if prefixByteLength > MaxBytePrefix {
   715  		return sql.ErrKeyTooLong.New()
   716  	}
   717  
   718  	// The specified prefix length is longer than the column
   719  	maxByteLength := int64(schCol.Type.MaxTextResponseByteLength(ctx))
   720  	if prefixByteLength > maxByteLength {
   721  		return sql.ErrInvalidIndexPrefix.New(schCol.Name)
   722  	}
   723  
   724  	strictMysqlCompatibility, err := isStrictMysqlCompatibilityEnabled(ctx)
   725  	if err != nil {
   726  		return err
   727  	}
   728  
   729  	// Prefix length is required for BLOB and TEXT columns.
   730  	if types.IsTextBlob(schCol.Type) && prefixByteLength == 0 {
   731  		// MariaDB extends this behavior so that unique indexes don't require a prefix length.
   732  		if !isUnique || strictMysqlCompatibility {
   733  			return sql.ErrInvalidBlobTextKey.New(schCol.Name)
   734  		}
   735  
   736  		// The hash we compute doesn't take into account the collation settings of the column, so in a
   737  		// case-insensitive collation, although "YES" and "yes" are equivalent, they will still generate
   738  		// different hashes which won't correctly identify a real uniqueness constraint violation.
   739  		stringType, ok := schCol.Type.(types.StringType)
   740  		if ok {
   741  			collation := stringType.Collation().Collation()
   742  			if !collation.IsCaseSensitive || !collation.IsAccentSensitive {
   743  				return sql.ErrCollationNotSupportedOnUniqueTextIndex.New()
   744  			}
   745  		}
   746  	}
   747  
   748  	return nil
   749  }
   750  
   751  // validateIndexType prevents creating invalid indexes
   752  func validateIndexType(ctx *sql.Context, cols []sql.IndexColumn, sch sql.Schema, constraint sql.IndexConstraint) error {
   753  	for _, idxCol := range cols {
   754  		idx := sch.IndexOfColName(idxCol.Name)
   755  		if idx == -1 {
   756  			return sql.ErrColumnNotFound.New(idxCol.Name)
   757  		}
   758  		schCol := sch[idx]
   759  		err := validatePrefixLength(ctx, schCol, idxCol, constraint)
   760  		if err != nil {
   761  			return err
   762  		}
   763  	}
   764  	return nil
   765  }
   766  
   767  // missingIdxColumn takes in a set of IndexColumns and returns false, along with the offending column name, if
   768  // an index Column is not in an index.
   769  func missingIdxColumn(cols []sql.IndexColumn, sch sql.Schema, tableName string) (string, bool) {
   770  	for _, c := range cols {
   771  		if ok := sch.Contains(c.Name, tableName); !ok {
   772  			return c.Name, false
   773  		}
   774  	}
   775  
   776  	return "", true
   777  }
   778  
   779  func replaceInSchema(sch sql.Schema, col *sql.Column, tableName string) sql.Schema {
   780  	idx := sch.IndexOf(col.Name, tableName)
   781  	schCopy := make(sql.Schema, len(sch))
   782  	for i := range sch {
   783  		if i == idx {
   784  			cc := *col
   785  			// Some information about the column is not specified in a MODIFY COLUMN statement, such as being a key
   786  			cc.PrimaryKey = sch[i].PrimaryKey
   787  			cc.Source = sch[i].Source
   788  			if cc.PrimaryKey {
   789  				cc.Nullable = false
   790  			}
   791  
   792  			schCopy[i] = &cc
   793  
   794  		} else {
   795  			cc := *sch[i]
   796  			schCopy[i] = &cc
   797  		}
   798  	}
   799  	return schCopy
   800  }
   801  
   802  func renameInSchema(sch sql.Schema, oldColName, newColName, tableName string) sql.Schema {
   803  	idx := sch.IndexOf(oldColName, tableName)
   804  	schCopy := make(sql.Schema, len(sch))
   805  	for i := range sch {
   806  		if i == idx {
   807  			cc := *sch[i]
   808  			cc.Name = newColName
   809  			schCopy[i] = &cc
   810  		} else {
   811  			cc := *sch[i]
   812  			schCopy[i] = &cc
   813  		}
   814  	}
   815  	return schCopy
   816  }
   817  
   818  func removeInSchema(sch sql.Schema, colName, tableName string) sql.Schema {
   819  	idx := sch.IndexOf(colName, tableName)
   820  	if idx == -1 {
   821  		return sch
   822  	}
   823  
   824  	schCopy := make(sql.Schema, len(sch)-1)
   825  	for i := range sch {
   826  		if i < idx {
   827  			cc := *sch[i]
   828  			schCopy[i] = &cc
   829  		} else if i > idx {
   830  			cc := *sch[i]
   831  			schCopy[i-1] = &cc // We want to shift stuff over.
   832  		}
   833  	}
   834  	return schCopy
   835  }
   836  
   837  // TODO: make this work for CREATE TABLE statements where there's a non-pk auto increment column
   838  func validateAutoIncrementModify(schema sql.Schema, keyedColumns map[string]bool) error {
   839  	seen := false
   840  	for _, col := range schema {
   841  		if col.AutoIncrement {
   842  			// keyedColumns == nil means they are trying to add auto_increment column
   843  			if !col.PrimaryKey && !keyedColumns[col.Name] {
   844  				// AUTO_INCREMENT col must be a key
   845  				return sql.ErrInvalidAutoIncCols.New()
   846  			}
   847  			if col.Default != nil {
   848  				// AUTO_INCREMENT col cannot have default
   849  				return sql.ErrInvalidAutoIncCols.New()
   850  			}
   851  			if seen {
   852  				// there can be at most one AUTO_INCREMENT col
   853  				return sql.ErrInvalidAutoIncCols.New()
   854  			}
   855  			seen = true
   856  		}
   857  	}
   858  	return nil
   859  }
   860  
   861  func validateAutoIncrementAdd(ctx *sql.Context, schema sql.Schema, keyColumns map[string]bool) error {
   862  	seen := false
   863  	for _, col := range schema {
   864  		if col.AutoIncrement {
   865  			{
   866  				if !col.PrimaryKey && !keyColumns[col.Name] {
   867  					// AUTO_INCREMENT col must be a key
   868  					return sql.ErrInvalidAutoIncCols.New()
   869  				}
   870  				if col.Default != nil {
   871  					// AUTO_INCREMENT col cannot have default
   872  					return sql.ErrInvalidAutoIncCols.New()
   873  				}
   874  				if seen {
   875  					// there can be at most one AUTO_INCREMENT col
   876  					return sql.ErrInvalidAutoIncCols.New()
   877  				}
   878  				seen = true
   879  			}
   880  		}
   881  	}
   882  	return nil
   883  }
   884  
// textIndexPrefix is a length constant whose name suggests an index prefix length for TEXT columns.
// NOTE(review): it is not referenced in this part of the file; confirm where it is used and whether the unit
// is bytes or characters.
const textIndexPrefix = 1000
   886  
   887  // validateIndexes prevents creating tables with blob/text primary keys and indexes without a specified length
   888  // TODO: this method is very similar to validateIndexType...
   889  func validateIndexes(ctx *sql.Context, tableSpec *plan.TableSpec) error {
   890  	lwrNames := make(map[string]*sql.Column)
   891  	for _, col := range tableSpec.Schema.Schema {
   892  		lwrNames[strings.ToLower(col.Name)] = col
   893  	}
   894  	var hasPkIndexDef bool
   895  	for _, idx := range tableSpec.IdxDefs {
   896  		if idx.Constraint == sql.IndexConstraint_Primary {
   897  			hasPkIndexDef = true
   898  		}
   899  		for _, idxCol := range idx.Columns {
   900  			schCol, ok := lwrNames[strings.ToLower(idxCol.Name)]
   901  			if !ok {
   902  				return sql.ErrUnknownIndexColumn.New(idxCol.Name, idx.IndexName)
   903  			}
   904  			err := validatePrefixLength(ctx, schCol, idxCol, idx.Constraint)
   905  			if err != nil {
   906  				return err
   907  			}
   908  		}
   909  		if idx.Constraint == sql.IndexConstraint_Spatial {
   910  			if len(idx.Columns) != 1 {
   911  				return sql.ErrTooManyKeyParts.New(1)
   912  			}
   913  			schCol, _ := lwrNames[strings.ToLower(idx.Columns[0].Name)]
   914  			spatialCol, ok := schCol.Type.(sql.SpatialColumnType)
   915  			if !ok {
   916  				return sql.ErrBadSpatialIdxCol.New()
   917  			}
   918  			if schCol.Nullable {
   919  				return sql.ErrNullableSpatialIdx.New()
   920  			}
   921  			if _, ok = spatialCol.GetSpatialTypeSRID(); !ok {
   922  				ctx.Warn(3674, "The spatial index on column '%s' will not be used by the query optimizer since the column does not have an SRID attribute. Consider adding an SRID attribyte to the column.", schCol.Name)
   923  			}
   924  		}
   925  	}
   926  
   927  	// if there was not a PkIndexDef, then any primary key text/blob columns must not have index lengths
   928  	// otherwise, then it would've been validated before this
   929  	if !hasPkIndexDef {
   930  		for _, col := range tableSpec.Schema.Schema {
   931  			if col.PrimaryKey && types.IsTextBlob(col.Type) {
   932  				return sql.ErrInvalidBlobTextKey.New(col.Name)
   933  			}
   934  		}
   935  	}
   936  	return nil
   937  }
   938  
   939  // getTableIndexColumns returns the columns over which indexes are defined
   940  func getTableIndexColumns(ctx *sql.Context, table sql.Node) (map[string]bool, error) {
   941  	ia, err := newIndexAnalyzerForNode(ctx, table)
   942  	if err != nil {
   943  		return nil, err
   944  	}
   945  
   946  	keyedColumns := make(map[string]bool)
   947  	indexes := ia.IndexesByTable(ctx, ctx.GetCurrentDatabase(), getTableName(table))
   948  	for _, index := range indexes {
   949  		for _, expr := range index.Expressions() {
   950  			if col := plan.GetColumnFromIndexExpr(expr, getTable(table)); col != nil {
   951  				keyedColumns[col.Name] = true
   952  			}
   953  		}
   954  	}
   955  
   956  	return keyedColumns, nil
   957  }
   958  
   959  // getTableIndexNames returns the names of indexes associated with a table.
   960  func getTableIndexNames(ctx *sql.Context, _ *Analyzer, table sql.Node) ([]string, error) {
   961  	ia, err := newIndexAnalyzerForNode(ctx, table)
   962  	if err != nil {
   963  		return nil, err
   964  	}
   965  
   966  	indexes := ia.IndexesByTable(ctx, ctx.GetCurrentDatabase(), getTableName(table))
   967  	names := make([]string, len(indexes))
   968  
   969  	for i, index := range indexes {
   970  		names[i] = index.ID()
   971  	}
   972  
   973  	if hasPrimaryKeys(table.Schema()) {
   974  		names = append(names, "PRIMARY")
   975  	}
   976  
   977  	return names, nil
   978  }
   979  
   980  // validatePrimaryKey validates a primary key add or drop operation.
   981  func validatePrimaryKey(ctx *sql.Context, initialSch, sch sql.Schema, ai *plan.AlterPK) (sql.Schema, error) {
   982  	tableName := getTableName(ai.Table)
   983  	switch ai.Action {
   984  	case plan.PrimaryKeyAction_Create:
   985  		badColName, ok := missingIdxColumn(ai.Columns, sch, tableName)
   986  		if !ok {
   987  			return nil, sql.ErrKeyColumnDoesNotExist.New(badColName)
   988  		}
   989  
   990  		if hasPrimaryKeys(sch) {
   991  			return nil, sql.ErrMultiplePrimaryKeysDefined.New()
   992  		}
   993  
   994  		for _, idxCol := range ai.Columns {
   995  			schCol := sch[sch.IndexOf(idxCol.Name, tableName)]
   996  			err := validatePrefixLength(ctx, schCol, idxCol, sql.IndexConstraint_Primary)
   997  			if err != nil {
   998  				return nil, err
   999  			}
  1000  
  1001  			if schCol.Virtual {
  1002  				return nil, sql.ErrVirtualColumnPrimaryKey.New()
  1003  			}
  1004  		}
  1005  
  1006  		// Set the primary keys
  1007  		for _, col := range ai.Columns {
  1008  			sch[sch.IndexOf(col.Name, tableName)].PrimaryKey = true
  1009  		}
  1010  
  1011  		return sch, nil
  1012  	case plan.PrimaryKeyAction_Drop:
  1013  		if !hasPrimaryKeys(sch) {
  1014  			return nil, sql.ErrCantDropFieldOrKey.New("PRIMARY")
  1015  		}
  1016  
  1017  		for _, col := range sch {
  1018  			if col.PrimaryKey {
  1019  				col.PrimaryKey = false
  1020  			}
  1021  		}
  1022  
  1023  		return sch, nil
  1024  	default:
  1025  		return sch, nil
  1026  	}
  1027  }
  1028  
  1029  // validateAlterDefault validates the addition of a default value to a column.
  1030  func validateAlterDefault(initialSch, sch sql.Schema, as *plan.AlterDefaultSet) (sql.Schema, error) {
  1031  	idx := sch.IndexOf(as.ColumnName, getTableName(as.Table))
  1032  	if idx == -1 {
  1033  		return nil, sql.ErrTableColumnNotFound.New(as.ColumnName)
  1034  	}
  1035  
  1036  	copiedDefault, err := as.Default.WithChildren(as.Default.Children()...)
  1037  	if err != nil {
  1038  		return nil, err
  1039  	}
  1040  
  1041  	sch[idx].Default = copiedDefault.(*sql.ColumnDefaultValue)
  1042  
  1043  	return sch, err
  1044  }
  1045  
  1046  // validateDropDefault validates the dropping of a default value.
  1047  func validateDropDefault(initialSch, sch sql.Schema, ad *plan.AlterDefaultDrop) (sql.Schema, error) {
  1048  	idx := sch.IndexOf(ad.ColumnName, getTableName(ad.Table))
  1049  	if idx == -1 {
  1050  		return nil, sql.ErrTableColumnNotFound.New(ad.ColumnName)
  1051  	}
  1052  
  1053  	sch[idx].Default = nil
  1054  
  1055  	return sch, nil
  1056  }
  1057  
  1058  func hasPrimaryKeys(sch sql.Schema) bool {
  1059  	for _, c := range sch {
  1060  		if c.PrimaryKey {
  1061  			return true
  1062  		}
  1063  	}
  1064  
  1065  	return false
  1066  }