github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/alter_table.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  	"github.com/dolthub/go-mysql-server/sql/expression"
    23  	"github.com/dolthub/go-mysql-server/sql/transform"
    24  	"github.com/dolthub/go-mysql-server/sql/types"
    25  )
    26  
    27  type RenameTable struct {
    28  	ddlNode
    29  	OldNames    []string
    30  	NewNames    []string
    31  	alterTblDef bool
    32  }
    33  
    34  var _ sql.Node = (*RenameTable)(nil)
    35  var _ sql.Databaser = (*RenameTable)(nil)
    36  var _ sql.CollationCoercible = (*RenameTable)(nil)
    37  
    38  // NewRenameTable creates a new RenameTable node
    39  func NewRenameTable(db sql.Database, oldNames, newNames []string, alterTbl bool) *RenameTable {
    40  	return &RenameTable{
    41  		ddlNode:     ddlNode{db},
    42  		OldNames:    oldNames,
    43  		NewNames:    newNames,
    44  		alterTblDef: alterTbl,
    45  	}
    46  }
    47  
    48  func (r *RenameTable) WithDatabase(db sql.Database) (sql.Node, error) {
    49  	nr := *r
    50  	nr.Db = db
    51  	return &nr, nil
    52  }
    53  
    54  func (r *RenameTable) String() string {
    55  	return fmt.Sprintf("Rename table %s to %s", r.OldNames, r.NewNames)
    56  }
    57  
    58  func (r *RenameTable) IsReadOnly() bool {
    59  	return false
    60  }
    61  
    62  func (r *RenameTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
    63  	renamer, _ := r.Db.(sql.TableRenamer)
    64  	viewDb, _ := r.Db.(sql.ViewDatabase)
    65  	viewRegistry := ctx.GetViewRegistry()
    66  
    67  	for i, oldName := range r.OldNames {
    68  		if tbl, exists := r.tableExists(ctx, oldName); exists {
    69  			err := r.renameTable(ctx, renamer, tbl, oldName, r.NewNames[i])
    70  			if err != nil {
    71  				return nil, err
    72  			}
    73  		} else {
    74  			success, err := r.renameView(ctx, viewDb, viewRegistry, oldName, r.NewNames[i])
    75  			if err != nil {
    76  				return nil, err
    77  			} else if !success {
    78  				return nil, sql.ErrTableNotFound.New(oldName)
    79  			}
    80  		}
    81  	}
    82  
    83  	return sql.RowsToRowIter(sql.NewRow(types.NewOkResult(0))), nil
    84  }
    85  
    86  func (r *RenameTable) WithChildren(children ...sql.Node) (sql.Node, error) {
    87  	return NillaryWithChildren(r, children...)
    88  }
    89  
    90  // CheckPrivileges implements the interface sql.Node.
    91  func (r *RenameTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
    92  	var operations []sql.PrivilegedOperation
    93  	for _, oldName := range r.OldNames {
    94  		subject := sql.PrivilegeCheckSubject{
    95  			Database: CheckPrivilegeNameForDatabase(r.Db),
    96  			Table:    oldName,
    97  		}
    98  		operations = append(operations, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter, sql.PrivilegeType_Drop))
    99  	}
   100  	for _, newName := range r.NewNames {
   101  		subject := sql.PrivilegeCheckSubject{
   102  			Database: CheckPrivilegeNameForDatabase(r.Db),
   103  			Table:    newName,
   104  		}
   105  		operations = append(operations, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Create, sql.PrivilegeType_Insert))
   106  	}
   107  	return opChecker.UserHasPrivileges(ctx, operations...)
   108  }
   109  
   110  // CollationCoercibility implements the interface sql.CollationCoercible.
   111  func (*RenameTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   112  	return sql.Collation_binary, 7
   113  }
   114  
   115  func (r *RenameTable) tableExists(ctx *sql.Context, name string) (sql.Table, bool) {
   116  	tbl, ok, err := r.Db.GetTableInsensitive(ctx, name)
   117  	if err != nil || !ok {
   118  		return nil, false
   119  	}
   120  	return tbl, true
   121  }
   122  
   123  func (r *RenameTable) renameTable(ctx *sql.Context, renamer sql.TableRenamer, tbl sql.Table, oldName, newName string) error {
   124  	if renamer == nil {
   125  		return sql.ErrRenameTableNotSupported.New(r.Db.Name())
   126  	}
   127  
   128  	if fkTable, ok := tbl.(sql.ForeignKeyTable); ok {
   129  		parentFks, err := fkTable.GetReferencedForeignKeys(ctx)
   130  		if err != nil {
   131  			return err
   132  		}
   133  		for _, parentFk := range parentFks {
   134  			//TODO: support renaming tables across databases for foreign keys
   135  			if strings.ToLower(parentFk.Database) != strings.ToLower(parentFk.ParentDatabase) {
   136  				return fmt.Errorf("updating foreign key table names across databases is not yet supported")
   137  			}
   138  			parentFk.ParentTable = newName
   139  			childTbl, ok, err := r.Db.GetTableInsensitive(ctx, parentFk.Table)
   140  			if err != nil {
   141  				return err
   142  			}
   143  			if !ok {
   144  				return sql.ErrTableNotFound.New(parentFk.Table)
   145  			}
   146  			childFkTbl, ok := childTbl.(sql.ForeignKeyTable)
   147  			if !ok {
   148  				return fmt.Errorf("referenced table `%s` supports foreign keys but declaring table `%s` does not", parentFk.ParentTable, parentFk.Table)
   149  			}
   150  			err = childFkTbl.UpdateForeignKey(ctx, parentFk.Name, parentFk)
   151  			if err != nil {
   152  				return err
   153  			}
   154  		}
   155  
   156  		fks, err := fkTable.GetDeclaredForeignKeys(ctx)
   157  		if err != nil {
   158  			return err
   159  		}
   160  		for _, fk := range fks {
   161  			fk.Table = newName
   162  			err = fkTable.UpdateForeignKey(ctx, fk.Name, fk)
   163  			if err != nil {
   164  				return err
   165  			}
   166  		}
   167  	}
   168  
   169  	err := renamer.RenameTable(ctx, oldName, newName)
   170  	if err != nil {
   171  		return err
   172  	}
   173  
   174  	return nil
   175  }
   176  
   177  func (r *RenameTable) renameView(ctx *sql.Context, viewDb sql.ViewDatabase, vr *sql.ViewRegistry, oldName, newName string) (bool, error) {
   178  	if viewDb != nil {
   179  		oldView, exists, err := viewDb.GetViewDefinition(ctx, oldName)
   180  		if err != nil {
   181  			return false, err
   182  		} else if !exists {
   183  			return false, nil
   184  		}
   185  
   186  		if r.alterTblDef {
   187  			return false, sql.ErrExpectedTableFoundView.New(fmt.Sprintf("'%s.%s'", r.Db.Name(), oldName))
   188  		}
   189  
   190  		err = viewDb.DropView(ctx, oldName)
   191  		if err != nil {
   192  			return false, err
   193  		}
   194  
   195  		err = viewDb.CreateView(ctx, newName, oldView.TextDefinition, oldView.CreateViewStatement)
   196  		if err != nil {
   197  			return false, err
   198  		}
   199  
   200  		return true, nil
   201  	} else {
   202  		view, exists := vr.View(r.Db.Name(), oldName)
   203  		if !exists {
   204  			return false, nil
   205  		}
   206  
   207  		if r.alterTblDef {
   208  			return false, sql.ErrExpectedTableFoundView.New(fmt.Sprintf("'%s.%s'", r.Db.Name(), oldName))
   209  		}
   210  
   211  		err := vr.Delete(r.Db.Name(), oldName)
   212  		if err != nil {
   213  			return false, nil
   214  		}
   215  		err = vr.Register(r.Db.Name(), sql.NewView(newName, view.Definition(), view.TextDefinition(), view.CreateStatement()))
   216  		if err != nil {
   217  			return false, nil
   218  		}
   219  		return true, nil
   220  	}
   221  }
   222  
   223  type AddColumn struct {
   224  	ddlNode
   225  	Table     sql.Node
   226  	column    *sql.Column
   227  	order     *sql.ColumnOrder
   228  	targetSch sql.Schema
   229  }
   230  
   231  var _ sql.Node = (*AddColumn)(nil)
   232  var _ sql.Expressioner = (*AddColumn)(nil)
   233  var _ sql.SchemaTarget = (*AddColumn)(nil)
   234  var _ sql.CollationCoercible = (*AddColumn)(nil)
   235  
   236  func (a *AddColumn) DebugString() string {
   237  	pr := sql.NewTreePrinter()
   238  	pr.WriteNode("add column %s to %s", a.column.Name, a.Table)
   239  
   240  	var children []string
   241  	children = append(children, sql.DebugString(a.column))
   242  	for _, col := range a.targetSch {
   243  		children = append(children, sql.DebugString(col))
   244  	}
   245  
   246  	pr.WriteChildren(children...)
   247  	return pr.String()
   248  }
   249  
   250  func NewAddColumnResolved(table *ResolvedTable, column sql.Column, order *sql.ColumnOrder) *AddColumn {
   251  	column.Source = table.Name()
   252  	return &AddColumn{
   253  		ddlNode: ddlNode{Db: table.SqlDatabase},
   254  		Table:   table,
   255  		column:  &column,
   256  		order:   order,
   257  	}
   258  }
   259  
   260  func NewAddColumn(database sql.Database, table *UnresolvedTable, column *sql.Column, order *sql.ColumnOrder) *AddColumn {
   261  	column.Source = table.name
   262  	return &AddColumn{
   263  		ddlNode: ddlNode{Db: database},
   264  		Table:   table,
   265  		column:  column,
   266  		order:   order,
   267  	}
   268  }
   269  
   270  func (a *AddColumn) Column() *sql.Column {
   271  	return a.column
   272  }
   273  
   274  func (a *AddColumn) Order() *sql.ColumnOrder {
   275  	return a.order
   276  }
   277  
   278  func (a *AddColumn) IsReadOnly() bool {
   279  	return false
   280  }
   281  
   282  func (a *AddColumn) WithDatabase(db sql.Database) (sql.Node, error) {
   283  	na := *a
   284  	na.Db = db
   285  	return &na, nil
   286  }
   287  
   288  // Schema implements the sql.Node interface.
   289  func (a *AddColumn) Schema() sql.Schema {
   290  	return types.OkResultSchema
   291  }
   292  
   293  func (a *AddColumn) String() string {
   294  	return fmt.Sprintf("add column %s", a.column.Name)
   295  }
   296  
   297  func (a *AddColumn) Expressions() []sql.Expression {
   298  	return append(transform.WrappedColumnDefaults(a.targetSch), transform.WrappedColumnDefaults(sql.Schema{a.column})...)
   299  }
   300  
   301  func (a AddColumn) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
   302  	if len(exprs) != 1+len(a.targetSch) {
   303  		return nil, sql.ErrInvalidChildrenNumber.New(a, len(exprs), 1+len(a.targetSch))
   304  	}
   305  
   306  	sch, err := transform.SchemaWithDefaults(a.targetSch, exprs[:len(a.targetSch)])
   307  	if err != nil {
   308  		return nil, err
   309  	}
   310  
   311  	a.targetSch = sch
   312  
   313  	colSchema := sql.Schema{a.column}
   314  	colSchema, err = transform.SchemaWithDefaults(colSchema, exprs[len(exprs)-1:])
   315  	if err != nil {
   316  		return nil, err
   317  	}
   318  
   319  	// *sql.Column is a reference type, make a copy before we modify it so we don't affect the original node
   320  	a.column = colSchema[0]
   321  	return &a, nil
   322  }
   323  
   324  // Resolved implements the Resolvable interface.
   325  func (a *AddColumn) Resolved() bool {
   326  	return a.ddlNode.Resolved() && a.Table.Resolved() && a.column.Default.Resolved() && a.targetSch.Resolved()
   327  }
   328  
   329  // WithTargetSchema implements sql.SchemaTarget
   330  func (a AddColumn) WithTargetSchema(schema sql.Schema) (sql.Node, error) {
   331  	a.targetSch = schema
   332  	return &a, nil
   333  }
   334  
   335  func (a *AddColumn) TargetSchema() sql.Schema {
   336  	return a.targetSch
   337  }
   338  
   339  func (a *AddColumn) ValidateDefaultPosition(tblSch sql.Schema) error {
   340  	colsAfterThis := map[string]*sql.Column{a.column.Name: a.column}
   341  	if a.order != nil {
   342  		if a.order.First {
   343  			for i := 0; i < len(tblSch); i++ {
   344  				colsAfterThis[tblSch[i].Name] = tblSch[i]
   345  			}
   346  		} else {
   347  			i := 1
   348  			for ; i < len(tblSch); i++ {
   349  				if tblSch[i-1].Name == a.order.AfterColumn {
   350  					break
   351  				}
   352  			}
   353  			for ; i < len(tblSch); i++ {
   354  				colsAfterThis[tblSch[i].Name] = tblSch[i]
   355  			}
   356  		}
   357  	}
   358  
   359  	err := inspectDefaultForInvalidColumns(a.column, colsAfterThis)
   360  	if err != nil {
   361  		return err
   362  	}
   363  
   364  	return nil
   365  }
   366  
   367  func inspectDefaultForInvalidColumns(col *sql.Column, columnsAfterThis map[string]*sql.Column) error {
   368  	if col.Default == nil {
   369  		return nil
   370  	}
   371  	var err error
   372  	sql.Inspect(col.Default, func(expr sql.Expression) bool {
   373  		switch expr := expr.(type) {
   374  		case *expression.GetField:
   375  			if col, ok := columnsAfterThis[expr.Name()]; ok && col.Default != nil && !col.Default.IsLiteral() {
   376  				err = sql.ErrInvalidDefaultValueOrder.New(col.Name)
   377  				return false
   378  			}
   379  		}
   380  		return true
   381  	})
   382  	return err
   383  }
   384  
   385  func (a AddColumn) WithChildren(children ...sql.Node) (sql.Node, error) {
   386  	if len(children) != 1 {
   387  		return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 1)
   388  	}
   389  	a.Table = children[0]
   390  	return &a, nil
   391  }
   392  
   393  // CheckPrivileges implements the interface sql.Node.
   394  func (a *AddColumn) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   395  	subject := sql.PrivilegeCheckSubject{
   396  		Database: CheckPrivilegeNameForDatabase(a.Db),
   397  		Table:    getTableName(a.Table),
   398  	}
   399  	return opChecker.UserHasPrivileges(ctx,
   400  		sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter))
   401  }
   402  
   403  // CollationCoercibility implements the interface sql.CollationCoercible.
   404  func (*AddColumn) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   405  	return sql.Collation_binary, 7
   406  }
   407  
   408  func (a *AddColumn) Children() []sql.Node {
   409  	return []sql.Node{a.Table}
   410  }
   411  
   412  // colDefault expression evaluates the column default for a row being inserted, correctly handling zero values and
   413  // nulls
   414  type ColDefaultExpression struct {
   415  	Column *sql.Column
   416  }
   417  
   418  var _ sql.Expression = ColDefaultExpression{}
   419  var _ sql.CollationCoercible = ColDefaultExpression{}
   420  
   421  func (c ColDefaultExpression) Resolved() bool   { return true }
   422  func (c ColDefaultExpression) String() string   { return "" }
   423  func (c ColDefaultExpression) Type() sql.Type   { return c.Column.Type }
   424  func (c ColDefaultExpression) IsNullable() bool { return c.Column.Default == nil }
   425  func (c ColDefaultExpression) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   426  	if c.Column != nil && c.Column.Default != nil {
   427  		return c.Column.Default.CollationCoercibility(ctx)
   428  	}
   429  	return sql.Collation_binary, 6
   430  }
   431  
   432  func (c ColDefaultExpression) Children() []sql.Expression {
   433  	panic("ColDefaultExpression is only meant for immediate evaluation and should never be modified")
   434  }
   435  
   436  func (c ColDefaultExpression) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   437  	panic("ColDefaultExpression is only meant for immediate evaluation and should never be modified")
   438  }
   439  
   440  func (c ColDefaultExpression) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   441  	columnDefaultExpr := c.Column.Default
   442  	if columnDefaultExpr == nil {
   443  		columnDefaultExpr = c.Column.Generated
   444  	}
   445  
   446  	if columnDefaultExpr == nil && !c.Column.Nullable {
   447  		val := c.Column.Type.Zero()
   448  		ret, _, err := c.Column.Type.Convert(val)
   449  		return ret, err
   450  	} else if columnDefaultExpr != nil {
   451  		val, err := columnDefaultExpr.Eval(ctx, row)
   452  		if err != nil {
   453  			return nil, err
   454  		}
   455  		ret, _, err := c.Column.Type.Convert(val)
   456  		return ret, err
   457  	}
   458  
   459  	return nil, nil
   460  }
   461  
   462  type DropColumn struct {
   463  	ddlNode
   464  	Table        sql.Node
   465  	Column       string
   466  	checks       sql.CheckConstraints
   467  	targetSchema sql.Schema
   468  }
   469  
   470  var _ sql.Node = (*DropColumn)(nil)
   471  var _ sql.Databaser = (*DropColumn)(nil)
   472  var _ sql.SchemaTarget = (*DropColumn)(nil)
   473  var _ sql.CheckConstraintNode = (*DropColumn)(nil)
   474  var _ sql.CollationCoercible = (*DropColumn)(nil)
   475  
   476  func NewDropColumnResolved(table *ResolvedTable, column string) *DropColumn {
   477  	return &DropColumn{
   478  		ddlNode: ddlNode{Db: table.SqlDatabase},
   479  		Table:   table,
   480  		Column:  column,
   481  	}
   482  }
   483  
   484  func NewDropColumn(database sql.Database, table *UnresolvedTable, column string) *DropColumn {
   485  	return &DropColumn{
   486  		ddlNode: ddlNode{Db: database},
   487  		Table:   table,
   488  		Column:  column,
   489  	}
   490  }
   491  
   492  func (d *DropColumn) Checks() sql.CheckConstraints {
   493  	return d.checks
   494  }
   495  
   496  func (d *DropColumn) WithChecks(checks sql.CheckConstraints) sql.Node {
   497  	ret := *d
   498  	ret.checks = checks
   499  	return &ret
   500  }
   501  
   502  func (d *DropColumn) WithDatabase(db sql.Database) (sql.Node, error) {
   503  	nd := *d
   504  	nd.Db = db
   505  	return &nd, nil
   506  }
   507  
   508  func (d *DropColumn) String() string {
   509  	return fmt.Sprintf("drop column %s", d.Column)
   510  }
   511  
   512  func (d *DropColumn) IsReadOnly() bool {
   513  	return false
   514  }
   515  
   516  // Validate returns an error if this drop column operation is invalid (because it would invalidate a column default
   517  // or other constraint).
   518  // TODO: move this check to analyzer
   519  func (d *DropColumn) Validate(ctx *sql.Context, tbl sql.Table) error {
   520  	colIdx := d.targetSchema.IndexOfColName(d.Column)
   521  	if colIdx == -1 {
   522  		return sql.ErrTableColumnNotFound.New(tbl.Name(), d.Column)
   523  	}
   524  
   525  	for _, col := range d.targetSchema {
   526  		if col.Default == nil {
   527  			continue
   528  		}
   529  		var err error
   530  		sql.Inspect(col.Default, func(expr sql.Expression) bool {
   531  			switch expr := expr.(type) {
   532  			case *expression.GetField:
   533  				if expr.Name() == d.Column {
   534  					err = sql.ErrDropColumnReferencedInDefault.New(d.Column, expr.Name())
   535  					return false
   536  				}
   537  			}
   538  			return true
   539  		})
   540  		if err != nil {
   541  			return err
   542  		}
   543  	}
   544  
   545  	if fkTable, ok := tbl.(sql.ForeignKeyTable); ok {
   546  		lowercaseColumn := strings.ToLower(d.Column)
   547  		fks, err := fkTable.GetDeclaredForeignKeys(ctx)
   548  		if err != nil {
   549  			return err
   550  		}
   551  		for _, fk := range fks {
   552  			for _, fkCol := range fk.Columns {
   553  				if lowercaseColumn == strings.ToLower(fkCol) {
   554  					return sql.ErrForeignKeyDropColumn.New(d.Column, fk.Name)
   555  				}
   556  			}
   557  		}
   558  		parentFks, err := fkTable.GetReferencedForeignKeys(ctx)
   559  		if err != nil {
   560  			return err
   561  		}
   562  		for _, parentFk := range parentFks {
   563  			for _, parentFkCol := range parentFk.Columns {
   564  				if lowercaseColumn == strings.ToLower(parentFkCol) {
   565  					return sql.ErrForeignKeyDropColumn.New(d.Column, parentFk.Name)
   566  				}
   567  			}
   568  		}
   569  	}
   570  
   571  	return nil
   572  }
   573  
   574  func (d *DropColumn) Schema() sql.Schema {
   575  	return types.OkResultSchema
   576  }
   577  
   578  func (d *DropColumn) Resolved() bool {
   579  	return d.Table.Resolved() && d.ddlNode.Resolved() && d.targetSchema.Resolved()
   580  }
   581  
   582  func (d *DropColumn) Children() []sql.Node {
   583  	return []sql.Node{d.Table}
   584  }
   585  
   586  func (d DropColumn) WithChildren(children ...sql.Node) (sql.Node, error) {
   587  	if len(children) != 1 {
   588  		return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1)
   589  	}
   590  	d.Table = children[0]
   591  	return &d, nil
   592  }
   593  
   594  // CheckPrivileges implements the interface sql.Node.
   595  func (d *DropColumn) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   596  	subject := sql.PrivilegeCheckSubject{
   597  		Database: CheckPrivilegeNameForDatabase(d.Db),
   598  		Table:    getTableName(d.Table),
   599  	}
   600  	return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter))
   601  }
   602  
   603  // CollationCoercibility implements the interface sql.CollationCoercible.
   604  func (*DropColumn) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   605  	return sql.Collation_binary, 7
   606  }
   607  
   608  func (d DropColumn) WithTargetSchema(schema sql.Schema) (sql.Node, error) {
   609  	d.targetSchema = schema
   610  	return &d, nil
   611  }
   612  
   613  func (d *DropColumn) TargetSchema() sql.Schema {
   614  	return d.targetSchema
   615  }
   616  
   617  func (d *DropColumn) Expressions() []sql.Expression {
   618  	return transform.WrappedColumnDefaults(d.targetSchema)
   619  }
   620  
   621  func (d DropColumn) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
   622  	if len(exprs) != len(d.targetSchema) {
   623  		return nil, sql.ErrInvalidChildrenNumber.New(d, len(exprs), len(d.targetSchema))
   624  	}
   625  
   626  	sch, err := transform.SchemaWithDefaults(d.targetSchema, exprs)
   627  	if err != nil {
   628  		return nil, err
   629  	}
   630  	d.targetSchema = sch
   631  
   632  	return &d, nil
   633  }
   634  
   635  type RenameColumn struct {
   636  	ddlNode
   637  	Table         sql.Node
   638  	ColumnName    string
   639  	NewColumnName string
   640  	checks        sql.CheckConstraints
   641  	targetSchema  sql.Schema
   642  }
   643  
   644  var _ sql.Node = (*RenameColumn)(nil)
   645  var _ sql.Databaser = (*RenameColumn)(nil)
   646  var _ sql.SchemaTarget = (*RenameColumn)(nil)
   647  var _ sql.CheckConstraintNode = (*RenameColumn)(nil)
   648  var _ sql.CollationCoercible = (*RenameColumn)(nil)
   649  
   650  func NewRenameColumnResolved(table *ResolvedTable, columnName string, newColumnName string) *RenameColumn {
   651  	return &RenameColumn{
   652  		ddlNode:       ddlNode{Db: table.SqlDatabase},
   653  		Table:         table,
   654  		ColumnName:    columnName,
   655  		NewColumnName: newColumnName,
   656  	}
   657  }
   658  
   659  func NewRenameColumn(database sql.Database, table *UnresolvedTable, columnName string, newColumnName string) *RenameColumn {
   660  	return &RenameColumn{
   661  		ddlNode:       ddlNode{Db: database},
   662  		Table:         table,
   663  		ColumnName:    columnName,
   664  		NewColumnName: newColumnName,
   665  	}
   666  }
   667  
   668  func (r *RenameColumn) Checks() sql.CheckConstraints {
   669  	return r.checks
   670  }
   671  
   672  func (r *RenameColumn) WithChecks(checks sql.CheckConstraints) sql.Node {
   673  	ret := *r
   674  	ret.checks = checks
   675  	return &ret
   676  }
   677  
   678  func (r *RenameColumn) WithDatabase(db sql.Database) (sql.Node, error) {
   679  	nr := *r
   680  	nr.Db = db
   681  	return &nr, nil
   682  }
   683  
   684  func (r RenameColumn) WithTargetSchema(schema sql.Schema) (sql.Node, error) {
   685  	r.targetSchema = schema
   686  	return &r, nil
   687  }
   688  
   689  func (r *RenameColumn) TargetSchema() sql.Schema {
   690  	return r.targetSchema
   691  }
   692  
   693  func (r *RenameColumn) String() string {
   694  	return fmt.Sprintf("rename column %s to %s", r.ColumnName, r.NewColumnName)
   695  }
   696  
   697  func (r *RenameColumn) IsReadOnly() bool {
   698  	return false
   699  }
   700  
   701  func (r *RenameColumn) DebugString() string {
   702  	pr := sql.NewTreePrinter()
   703  	pr.WriteNode("rename column %s to %s", r.ColumnName, r.NewColumnName)
   704  
   705  	var children []string
   706  	for _, col := range r.targetSchema {
   707  		children = append(children, sql.DebugString(col))
   708  	}
   709  
   710  	pr.WriteChildren(children...)
   711  	return pr.String()
   712  }
   713  
   714  func (r *RenameColumn) Resolved() bool {
   715  	return r.Table.Resolved() && r.ddlNode.Resolved() && r.targetSchema.Resolved()
   716  }
   717  
   718  func (r *RenameColumn) Schema() sql.Schema {
   719  	return types.OkResultSchema
   720  }
   721  
   722  func (r *RenameColumn) Expressions() []sql.Expression {
   723  	return transform.WrappedColumnDefaults(r.targetSchema)
   724  }
   725  
   726  func (r RenameColumn) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
   727  	if len(exprs) != len(r.targetSchema) {
   728  		return nil, sql.ErrInvalidChildrenNumber.New(r, len(exprs), len(r.targetSchema))
   729  	}
   730  
   731  	sch, err := transform.SchemaWithDefaults(r.targetSchema, exprs)
   732  	if err != nil {
   733  		return nil, err
   734  	}
   735  
   736  	r.targetSchema = sch
   737  	return &r, nil
   738  }
   739  
   740  func (r *RenameColumn) Children() []sql.Node {
   741  	return []sql.Node{r.Table}
   742  }
   743  
   744  func (r RenameColumn) WithChildren(children ...sql.Node) (sql.Node, error) {
   745  	if len(children) != 1 {
   746  		return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 1)
   747  	}
   748  	r.Table = children[0]
   749  	return &r, nil
   750  }
   751  
   752  // CheckPrivileges implements the interface sql.Node.
   753  func (r *RenameColumn) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   754  	subject := sql.PrivilegeCheckSubject{
   755  		Database: CheckPrivilegeNameForDatabase(r.Db),
   756  		Table:    getTableName(r.Table),
   757  	}
   758  
   759  	return opChecker.UserHasPrivileges(ctx,
   760  		sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter))
   761  }
   762  
   763  // CollationCoercibility implements the interface sql.CollationCoercible.
   764  func (*RenameColumn) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   765  	return sql.Collation_binary, 7
   766  }
   767  
   768  type ModifyColumn struct {
   769  	ddlNode
   770  	Table        sql.Node
   771  	columnName   string
   772  	column       *sql.Column
   773  	order        *sql.ColumnOrder
   774  	targetSchema sql.Schema
   775  }
   776  
   777  var _ sql.Node = (*ModifyColumn)(nil)
   778  var _ sql.Expressioner = (*ModifyColumn)(nil)
   779  var _ sql.Databaser = (*ModifyColumn)(nil)
   780  var _ sql.SchemaTarget = (*ModifyColumn)(nil)
   781  var _ sql.CollationCoercible = (*ModifyColumn)(nil)
   782  
   783  func NewModifyColumnResolved(table *ResolvedTable, columnName string, column sql.Column, order *sql.ColumnOrder) *ModifyColumn {
   784  	column.Source = table.Name()
   785  	return &ModifyColumn{
   786  		ddlNode:    ddlNode{Db: table.SqlDatabase},
   787  		Table:      table,
   788  		columnName: columnName,
   789  		column:     &column,
   790  		order:      order,
   791  	}
   792  }
   793  
   794  func NewModifyColumn(database sql.Database, table *UnresolvedTable, columnName string, column *sql.Column, order *sql.ColumnOrder) *ModifyColumn {
   795  	column.Source = table.name
   796  	return &ModifyColumn{
   797  		ddlNode:    ddlNode{Db: database},
   798  		Table:      table,
   799  		columnName: columnName,
   800  		column:     column,
   801  		order:      order,
   802  	}
   803  }
   804  
   805  func (m *ModifyColumn) WithDatabase(db sql.Database) (sql.Node, error) {
   806  	nm := *m
   807  	nm.Db = db
   808  	return &nm, nil
   809  }
   810  
   811  func (m *ModifyColumn) Column() string {
   812  	return m.columnName
   813  }
   814  
   815  func (m *ModifyColumn) NewColumn() *sql.Column {
   816  	return m.column
   817  }
   818  
   819  func (m *ModifyColumn) Order() *sql.ColumnOrder {
   820  	return m.order
   821  }
   822  
   823  // Schema implements the sql.Node interface.
   824  func (m *ModifyColumn) Schema() sql.Schema {
   825  	return types.OkResultSchema
   826  }
   827  
   828  func (m *ModifyColumn) String() string {
   829  	return fmt.Sprintf("modify column %s", m.column.Name)
   830  }
   831  
   832  func (m *ModifyColumn) IsReadOnly() bool {
   833  	return false
   834  }
   835  
   836  func (m ModifyColumn) WithTargetSchema(schema sql.Schema) (sql.Node, error) {
   837  	m.targetSchema = schema
   838  	return &m, nil
   839  }
   840  
   841  func (m *ModifyColumn) TargetSchema() sql.Schema {
   842  	return m.targetSchema
   843  }
   844  
   845  func (m *ModifyColumn) Children() []sql.Node {
   846  	return []sql.Node{m.Table}
   847  }
   848  
   849  func (m ModifyColumn) WithChildren(children ...sql.Node) (sql.Node, error) {
   850  	if len(children) != 1 {
   851  		return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1)
   852  	}
   853  	m.Table = children[0]
   854  	return &m, nil
   855  }
   856  
   857  // CheckPrivileges implements the interface sql.Node.
   858  func (m *ModifyColumn) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   859  	subject := sql.PrivilegeCheckSubject{
   860  		Database: CheckPrivilegeNameForDatabase(m.Db),
   861  		Table:    getTableName(m.Table),
   862  	}
   863  	return opChecker.UserHasPrivileges(ctx,
   864  		sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter))
   865  }
   866  
   867  // CollationCoercibility implements the interface sql.CollationCoercible.
   868  func (*ModifyColumn) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   869  	return sql.Collation_binary, 7
   870  }
   871  
   872  func (m *ModifyColumn) Expressions() []sql.Expression {
   873  	return append(transform.WrappedColumnDefaults(m.targetSchema), expression.WrapExpressions(m.column.Default)...)
   874  }
   875  
   876  func (m ModifyColumn) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
   877  	if len(exprs) != 1+len(m.targetSchema) {
   878  		return nil, sql.ErrInvalidChildrenNumber.New(m, len(exprs), 1+len(m.targetSchema))
   879  	}
   880  
   881  	sch, err := transform.SchemaWithDefaults(m.targetSchema, exprs[:len(m.targetSchema)])
   882  	if err != nil {
   883  		return nil, err
   884  	}
   885  	m.targetSchema = sch
   886  
   887  	unwrappedColDefVal, ok := exprs[len(exprs)-1].(*expression.Wrapper).Unwrap().(*sql.ColumnDefaultValue)
   888  	if ok {
   889  		m.column.Default = unwrappedColDefVal
   890  	} else { // nil fails type check
   891  		m.column.Default = nil
   892  	}
   893  	return &m, nil
   894  }
   895  
   896  // Resolved implements the Resolvable interface.
   897  func (m *ModifyColumn) Resolved() bool {
   898  	return m.Table.Resolved() && m.column.Default.Resolved() && m.ddlNode.Resolved() && m.targetSchema.Resolved()
   899  }
   900  
   901  func (m *ModifyColumn) ValidateDefaultPosition(tblSch sql.Schema) error {
   902  	colsBeforeThis := make(map[string]*sql.Column)
   903  	colsAfterThis := make(map[string]*sql.Column) // includes the modified column
   904  	if m.order == nil {
   905  		i := 0
   906  		for ; i < len(tblSch); i++ {
   907  			if tblSch[i].Name == m.column.Name {
   908  				colsAfterThis[m.column.Name] = m.column
   909  				break
   910  			}
   911  			colsBeforeThis[tblSch[i].Name] = tblSch[i]
   912  		}
   913  		for ; i < len(tblSch); i++ {
   914  			colsAfterThis[tblSch[i].Name] = tblSch[i]
   915  		}
   916  	} else if m.order.First {
   917  		for i := 0; i < len(tblSch); i++ {
   918  			colsAfterThis[tblSch[i].Name] = tblSch[i]
   919  		}
   920  	} else {
   921  		i := 1
   922  		for ; i < len(tblSch); i++ {
   923  			colsBeforeThis[tblSch[i].Name] = tblSch[i]
   924  			if tblSch[i-1].Name == m.order.AfterColumn {
   925  				break
   926  			}
   927  		}
   928  		for ; i < len(tblSch); i++ {
   929  			colsAfterThis[tblSch[i].Name] = tblSch[i]
   930  		}
   931  		delete(colsBeforeThis, m.column.Name)
   932  		colsAfterThis[m.column.Name] = m.column
   933  	}
   934  
   935  	err := inspectDefaultForInvalidColumns(m.column, colsAfterThis)
   936  	if err != nil {
   937  		return err
   938  	}
   939  	thisCol := map[string]*sql.Column{m.column.Name: m.column}
   940  	for _, colBefore := range colsBeforeThis {
   941  		err = inspectDefaultForInvalidColumns(colBefore, thisCol)
   942  		if err != nil {
   943  			return err
   944  		}
   945  	}
   946  
   947  	return nil
   948  }
   949  
// AlterTableCollation is the plan node that changes the collation of a table. String renders it
// as "alter table <table> collate <collation>".
type AlterTableCollation struct {
	ddlNode
	// Table is the table being altered. The constructors accept either a *ResolvedTable or an
	// *UnresolvedTable; WithChildren may replace it with any sql.Node.
	Table     sql.Node
	// Collation is the collation to apply to the table.
	Collation sql.CollationID
}
   955  
   956  var _ sql.Node = (*AlterTableCollation)(nil)
   957  var _ sql.Databaser = (*AlterTableCollation)(nil)
   958  
   959  // NewAlterTableCollationResolved returns a new *AlterTableCollation
   960  func NewAlterTableCollationResolved(table *ResolvedTable, collation sql.CollationID) *AlterTableCollation {
   961  	return &AlterTableCollation{
   962  		ddlNode:   ddlNode{Db: table.SqlDatabase},
   963  		Table:     table,
   964  		Collation: collation,
   965  	}
   966  }
   967  
   968  // NewAlterTableCollation returns a new *AlterTableCollation
   969  func NewAlterTableCollation(database sql.Database, table *UnresolvedTable, collation sql.CollationID) *AlterTableCollation {
   970  	return &AlterTableCollation{
   971  		ddlNode:   ddlNode{Db: database},
   972  		Table:     table,
   973  		Collation: collation,
   974  	}
   975  }
   976  
   977  // WithDatabase implements the interface sql.Databaser.
   978  func (atc *AlterTableCollation) WithDatabase(db sql.Database) (sql.Node, error) {
   979  	natc := *atc
   980  	natc.Db = db
   981  	return &natc, nil
   982  }
   983  
   984  func (atc *AlterTableCollation) IsReadOnly() bool {
   985  	return false
   986  }
   987  
   988  // String implements the interface sql.Node.
   989  func (atc *AlterTableCollation) String() string {
   990  	return fmt.Sprintf("alter table %s collate %s", atc.Table.String(), atc.Collation.Name())
   991  }
   992  
   993  // DebugString implements the interface sql.Node.
   994  func (atc *AlterTableCollation) DebugString() string {
   995  	return atc.String()
   996  }
   997  
   998  // Resolved implements the interface sql.Node.
   999  func (atc *AlterTableCollation) Resolved() bool {
  1000  	return atc.Table.Resolved() && atc.ddlNode.Resolved()
  1001  }
  1002  
  1003  // Schema implements the interface sql.Node.
  1004  func (atc *AlterTableCollation) Schema() sql.Schema {
  1005  	return types.OkResultSchema
  1006  }
  1007  
  1008  // Children implements the interface sql.Node.
  1009  func (atc *AlterTableCollation) Children() []sql.Node {
  1010  	return []sql.Node{atc.Table}
  1011  }
  1012  
  1013  // WithChildren implements the interface sql.Node.
  1014  func (atc *AlterTableCollation) WithChildren(children ...sql.Node) (sql.Node, error) {
  1015  	if len(children) != 1 {
  1016  		return nil, sql.ErrInvalidChildrenNumber.New(atc, len(children), 1)
  1017  	}
  1018  	natc := *atc
  1019  	natc.Table = children[0]
  1020  	return &natc, nil
  1021  }
  1022  
  1023  // CheckPrivileges implements the interface sql.Node.
  1024  func (atc *AlterTableCollation) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
  1025  	subject := sql.PrivilegeCheckSubject{
  1026  		Database: CheckPrivilegeNameForDatabase(atc.Db),
  1027  		Table:    getTableName(atc.Table),
  1028  	}
  1029  
  1030  	return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter))
  1031  }