github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/alter_table.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package plan 16 17 import ( 18 "fmt" 19 "strings" 20 21 "github.com/dolthub/go-mysql-server/sql" 22 "github.com/dolthub/go-mysql-server/sql/expression" 23 "github.com/dolthub/go-mysql-server/sql/transform" 24 "github.com/dolthub/go-mysql-server/sql/types" 25 ) 26 27 type RenameTable struct { 28 ddlNode 29 OldNames []string 30 NewNames []string 31 alterTblDef bool 32 } 33 34 var _ sql.Node = (*RenameTable)(nil) 35 var _ sql.Databaser = (*RenameTable)(nil) 36 var _ sql.CollationCoercible = (*RenameTable)(nil) 37 38 // NewRenameTable creates a new RenameTable node 39 func NewRenameTable(db sql.Database, oldNames, newNames []string, alterTbl bool) *RenameTable { 40 return &RenameTable{ 41 ddlNode: ddlNode{db}, 42 OldNames: oldNames, 43 NewNames: newNames, 44 alterTblDef: alterTbl, 45 } 46 } 47 48 func (r *RenameTable) WithDatabase(db sql.Database) (sql.Node, error) { 49 nr := *r 50 nr.Db = db 51 return &nr, nil 52 } 53 54 func (r *RenameTable) String() string { 55 return fmt.Sprintf("Rename table %s to %s", r.OldNames, r.NewNames) 56 } 57 58 func (r *RenameTable) IsReadOnly() bool { 59 return false 60 } 61 62 func (r *RenameTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { 63 renamer, _ := r.Db.(sql.TableRenamer) 64 viewDb, _ := r.Db.(sql.ViewDatabase) 65 viewRegistry := ctx.GetViewRegistry() 66 67 for i, oldName := range r.OldNames { 68 if tbl, exists := r.tableExists(ctx, oldName); exists { 69 err := r.renameTable(ctx, renamer, tbl, oldName, r.NewNames[i]) 70 if err != nil { 71 return nil, err 72 } 73 } else { 74 success, err := r.renameView(ctx, viewDb, viewRegistry, oldName, r.NewNames[i]) 75 if err != nil { 76 return nil, err 77 } else if !success { 78 return nil, sql.ErrTableNotFound.New(oldName) 79 } 80 } 81 } 82 83 return sql.RowsToRowIter(sql.NewRow(types.NewOkResult(0))), nil 84 } 85 86 func (r *RenameTable) WithChildren(children ...sql.Node) (sql.Node, error) { 87 return NillaryWithChildren(r, children...) 88 } 89 90 // CheckPrivileges implements the interface sql.Node. 91 func (r *RenameTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { 92 var operations []sql.PrivilegedOperation 93 for _, oldName := range r.OldNames { 94 subject := sql.PrivilegeCheckSubject{ 95 Database: CheckPrivilegeNameForDatabase(r.Db), 96 Table: oldName, 97 } 98 operations = append(operations, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter, sql.PrivilegeType_Drop)) 99 } 100 for _, newName := range r.NewNames { 101 subject := sql.PrivilegeCheckSubject{ 102 Database: CheckPrivilegeNameForDatabase(r.Db), 103 Table: newName, 104 } 105 operations = append(operations, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Create, sql.PrivilegeType_Insert)) 106 } 107 return opChecker.UserHasPrivileges(ctx, operations...) 108 } 109 110 // CollationCoercibility implements the interface sql.CollationCoercible. 111 func (*RenameTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 112 return sql.Collation_binary, 7 113 } 114 115 func (r *RenameTable) tableExists(ctx *sql.Context, name string) (sql.Table, bool) { 116 tbl, ok, err := r.Db.GetTableInsensitive(ctx, name) 117 if err != nil || !ok { 118 return nil, false 119 } 120 return tbl, true 121 } 122 123 func (r *RenameTable) renameTable(ctx *sql.Context, renamer sql.TableRenamer, tbl sql.Table, oldName, newName string) error { 124 if renamer == nil { 125 return sql.ErrRenameTableNotSupported.New(r.Db.Name()) 126 } 127 128 if fkTable, ok := tbl.(sql.ForeignKeyTable); ok { 129 parentFks, err := fkTable.GetReferencedForeignKeys(ctx) 130 if err != nil { 131 return err 132 } 133 for _, parentFk := range parentFks { 134 //TODO: support renaming tables across databases for foreign keys 135 if strings.ToLower(parentFk.Database) != strings.ToLower(parentFk.ParentDatabase) { 136 return fmt.Errorf("updating foreign key table names across databases is not yet supported") 137 } 138 parentFk.ParentTable = newName 139 childTbl, ok, err := r.Db.GetTableInsensitive(ctx, parentFk.Table) 140 if err != nil { 141 return err 142 } 143 if !ok { 144 return sql.ErrTableNotFound.New(parentFk.Table) 145 } 146 childFkTbl, ok := childTbl.(sql.ForeignKeyTable) 147 if !ok { 148 return fmt.Errorf("referenced table `%s` supports foreign keys but declaring table `%s` does not", parentFk.ParentTable, parentFk.Table) 149 } 150 err = childFkTbl.UpdateForeignKey(ctx, parentFk.Name, parentFk) 151 if err != nil { 152 return err 153 } 154 } 155 156 fks, err := fkTable.GetDeclaredForeignKeys(ctx) 157 if err != nil { 158 return err 159 } 160 for _, fk := range fks { 161 fk.Table = newName 162 err = fkTable.UpdateForeignKey(ctx, fk.Name, fk) 163 if err != nil { 164 return err 165 } 166 } 167 } 168 169 err := renamer.RenameTable(ctx, oldName, newName) 170 if err != nil { 171 return err 172 } 173 174 return nil 175 } 176 177 func (r *RenameTable) renameView(ctx *sql.Context, viewDb sql.ViewDatabase, vr *sql.ViewRegistry, oldName, newName string) (bool, error) { 178 if viewDb != nil { 179 oldView, exists, err := viewDb.GetViewDefinition(ctx, oldName) 180 if err != nil { 181 return false, err 182 } else if !exists { 183 return false, nil 184 } 185 186 if r.alterTblDef { 187 return false, sql.ErrExpectedTableFoundView.New(fmt.Sprintf("'%s.%s'", r.Db.Name(), oldName)) 188 } 189 190 err = viewDb.DropView(ctx, oldName) 191 if err != nil { 192 return false, err 193 } 194 195 err = viewDb.CreateView(ctx, newName, oldView.TextDefinition, oldView.CreateViewStatement) 196 if err != nil { 197 return false, err 198 } 199 200 return true, nil 201 } else { 202 view, exists := vr.View(r.Db.Name(), oldName) 203 if !exists { 204 return false, nil 205 } 206 207 if r.alterTblDef { 208 return false, sql.ErrExpectedTableFoundView.New(fmt.Sprintf("'%s.%s'", r.Db.Name(), oldName)) 209 } 210 211 err := vr.Delete(r.Db.Name(), oldName) 212 if err != nil { 213 return false, nil 214 } 215 err = vr.Register(r.Db.Name(), sql.NewView(newName, view.Definition(), view.TextDefinition(), view.CreateStatement())) 216 if err != nil { 217 return false, nil 218 } 219 return true, nil 220 } 221 } 222 223 type AddColumn struct { 224 ddlNode 225 Table sql.Node 226 column *sql.Column 227 order *sql.ColumnOrder 228 targetSch sql.Schema 229 } 230 231 var _ sql.Node = (*AddColumn)(nil) 232 var _ sql.Expressioner = (*AddColumn)(nil) 233 var _ sql.SchemaTarget = (*AddColumn)(nil) 234 var _ sql.CollationCoercible = (*AddColumn)(nil) 235 236 func (a *AddColumn) DebugString() string { 237 pr := sql.NewTreePrinter() 238 pr.WriteNode("add column %s to %s", a.column.Name, a.Table) 239 240 var children []string 241 children = append(children, sql.DebugString(a.column)) 242 for _, col := range a.targetSch { 243 children = append(children, sql.DebugString(col)) 244 } 245 246 pr.WriteChildren(children...) 247 return pr.String() 248 } 249 250 func NewAddColumnResolved(table *ResolvedTable, column sql.Column, order *sql.ColumnOrder) *AddColumn { 251 column.Source = table.Name() 252 return &AddColumn{ 253 ddlNode: ddlNode{Db: table.SqlDatabase}, 254 Table: table, 255 column: &column, 256 order: order, 257 } 258 } 259 260 func NewAddColumn(database sql.Database, table *UnresolvedTable, column *sql.Column, order *sql.ColumnOrder) *AddColumn { 261 column.Source = table.name 262 return &AddColumn{ 263 ddlNode: ddlNode{Db: database}, 264 Table: table, 265 column: column, 266 order: order, 267 } 268 } 269 270 func (a *AddColumn) Column() *sql.Column { 271 return a.column 272 } 273 274 func (a *AddColumn) Order() *sql.ColumnOrder { 275 return a.order 276 } 277 278 func (a *AddColumn) IsReadOnly() bool { 279 return false 280 } 281 282 func (a *AddColumn) WithDatabase(db sql.Database) (sql.Node, error) { 283 na := *a 284 na.Db = db 285 return &na, nil 286 } 287 288 // Schema implements the sql.Node interface. 289 func (a *AddColumn) Schema() sql.Schema { 290 return types.OkResultSchema 291 } 292 293 func (a *AddColumn) String() string { 294 return fmt.Sprintf("add column %s", a.column.Name) 295 } 296 297 func (a *AddColumn) Expressions() []sql.Expression { 298 return append(transform.WrappedColumnDefaults(a.targetSch), transform.WrappedColumnDefaults(sql.Schema{a.column})...) 299 } 300 301 func (a AddColumn) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { 302 if len(exprs) != 1+len(a.targetSch) { 303 return nil, sql.ErrInvalidChildrenNumber.New(a, len(exprs), 1+len(a.targetSch)) 304 } 305 306 sch, err := transform.SchemaWithDefaults(a.targetSch, exprs[:len(a.targetSch)]) 307 if err != nil { 308 return nil, err 309 } 310 311 a.targetSch = sch 312 313 colSchema := sql.Schema{a.column} 314 colSchema, err = transform.SchemaWithDefaults(colSchema, exprs[len(exprs)-1:]) 315 if err != nil { 316 return nil, err 317 } 318 319 // *sql.Column is a reference type, make a copy before we modify it so we don't affect the original node 320 a.column = colSchema[0] 321 return &a, nil 322 } 323 324 // Resolved implements the Resolvable interface. 325 func (a *AddColumn) Resolved() bool { 326 return a.ddlNode.Resolved() && a.Table.Resolved() && a.column.Default.Resolved() && a.targetSch.Resolved() 327 } 328 329 // WithTargetSchema implements sql.SchemaTarget 330 func (a AddColumn) WithTargetSchema(schema sql.Schema) (sql.Node, error) { 331 a.targetSch = schema 332 return &a, nil 333 } 334 335 func (a *AddColumn) TargetSchema() sql.Schema { 336 return a.targetSch 337 } 338 339 func (a *AddColumn) ValidateDefaultPosition(tblSch sql.Schema) error { 340 colsAfterThis := map[string]*sql.Column{a.column.Name: a.column} 341 if a.order != nil { 342 if a.order.First { 343 for i := 0; i < len(tblSch); i++ { 344 colsAfterThis[tblSch[i].Name] = tblSch[i] 345 } 346 } else { 347 i := 1 348 for ; i < len(tblSch); i++ { 349 if tblSch[i-1].Name == a.order.AfterColumn { 350 break 351 } 352 } 353 for ; i < len(tblSch); i++ { 354 colsAfterThis[tblSch[i].Name] = tblSch[i] 355 } 356 } 357 } 358 359 err := inspectDefaultForInvalidColumns(a.column, colsAfterThis) 360 if err != nil { 361 return err 362 } 363 364 return nil 365 } 366 367 func inspectDefaultForInvalidColumns(col *sql.Column, columnsAfterThis map[string]*sql.Column) error { 368 if col.Default == nil { 369 return nil 370 } 371 var err error 372 sql.Inspect(col.Default, func(expr sql.Expression) bool { 373 switch expr := expr.(type) { 374 case *expression.GetField: 375 if col, ok := columnsAfterThis[expr.Name()]; ok && col.Default != nil && !col.Default.IsLiteral() { 376 err = sql.ErrInvalidDefaultValueOrder.New(col.Name) 377 return false 378 } 379 } 380 return true 381 }) 382 return err 383 } 384 385 func (a AddColumn) WithChildren(children ...sql.Node) (sql.Node, error) { 386 if len(children) != 1 { 387 return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 1) 388 } 389 a.Table = children[0] 390 return &a, nil 391 } 392 393 // CheckPrivileges implements the interface sql.Node. 394 func (a *AddColumn) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { 395 subject := sql.PrivilegeCheckSubject{ 396 Database: CheckPrivilegeNameForDatabase(a.Db), 397 Table: getTableName(a.Table), 398 } 399 return opChecker.UserHasPrivileges(ctx, 400 sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) 401 } 402 403 // CollationCoercibility implements the interface sql.CollationCoercible. 404 func (*AddColumn) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 405 return sql.Collation_binary, 7 406 } 407 408 func (a *AddColumn) Children() []sql.Node { 409 return []sql.Node{a.Table} 410 } 411 412 // colDefault expression evaluates the column default for a row being inserted, correctly handling zero values and 413 // nulls 414 type ColDefaultExpression struct { 415 Column *sql.Column 416 } 417 418 var _ sql.Expression = ColDefaultExpression{} 419 var _ sql.CollationCoercible = ColDefaultExpression{} 420 421 func (c ColDefaultExpression) Resolved() bool { return true } 422 func (c ColDefaultExpression) String() string { return "" } 423 func (c ColDefaultExpression) Type() sql.Type { return c.Column.Type } 424 func (c ColDefaultExpression) IsNullable() bool { return c.Column.Default == nil } 425 func (c ColDefaultExpression) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 426 if c.Column != nil && c.Column.Default != nil { 427 return c.Column.Default.CollationCoercibility(ctx) 428 } 429 return sql.Collation_binary, 6 430 } 431 432 func (c ColDefaultExpression) Children() []sql.Expression { 433 panic("ColDefaultExpression is only meant for immediate evaluation and should never be modified") 434 } 435 436 func (c ColDefaultExpression) WithChildren(children ...sql.Expression) (sql.Expression, error) { 437 panic("ColDefaultExpression is only meant for immediate evaluation and should never be modified") 438 } 439 440 func (c ColDefaultExpression) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 441 columnDefaultExpr := c.Column.Default 442 if columnDefaultExpr == nil { 443 columnDefaultExpr = c.Column.Generated 444 } 445 446 if columnDefaultExpr == nil && !c.Column.Nullable { 447 val := c.Column.Type.Zero() 448 ret, _, err := c.Column.Type.Convert(val) 449 return ret, err 450 } else if columnDefaultExpr != nil { 451 val, err := columnDefaultExpr.Eval(ctx, row) 452 if err != nil { 453 return nil, err 454 } 455 ret, _, err := c.Column.Type.Convert(val) 456 return ret, err 457 } 458 459 return nil, nil 460 } 461 462 type DropColumn struct { 463 ddlNode 464 Table sql.Node 465 Column string 466 checks sql.CheckConstraints 467 targetSchema sql.Schema 468 } 469 470 var _ sql.Node = (*DropColumn)(nil) 471 var _ sql.Databaser = (*DropColumn)(nil) 472 var _ sql.SchemaTarget = (*DropColumn)(nil) 473 var _ sql.CheckConstraintNode = (*DropColumn)(nil) 474 var _ sql.CollationCoercible = (*DropColumn)(nil) 475 476 func NewDropColumnResolved(table *ResolvedTable, column string) *DropColumn { 477 return &DropColumn{ 478 ddlNode: ddlNode{Db: table.SqlDatabase}, 479 Table: table, 480 Column: column, 481 } 482 } 483 484 func NewDropColumn(database sql.Database, table *UnresolvedTable, column string) *DropColumn { 485 return &DropColumn{ 486 ddlNode: ddlNode{Db: database}, 487 Table: table, 488 Column: column, 489 } 490 } 491 492 func (d *DropColumn) Checks() sql.CheckConstraints { 493 return d.checks 494 } 495 496 func (d *DropColumn) WithChecks(checks sql.CheckConstraints) sql.Node { 497 ret := *d 498 ret.checks = checks 499 return &ret 500 } 501 502 func (d *DropColumn) WithDatabase(db sql.Database) (sql.Node, error) { 503 nd := *d 504 nd.Db = db 505 return &nd, nil 506 } 507 508 func (d *DropColumn) String() string { 509 return fmt.Sprintf("drop column %s", d.Column) 510 } 511 512 func (d *DropColumn) IsReadOnly() bool { 513 return false 514 } 515 516 // Validate returns an error if this drop column operation is invalid (because it would invalidate a column default 517 // or other constraint). 518 // TODO: move this check to analyzer 519 func (d *DropColumn) Validate(ctx *sql.Context, tbl sql.Table) error { 520 colIdx := d.targetSchema.IndexOfColName(d.Column) 521 if colIdx == -1 { 522 return sql.ErrTableColumnNotFound.New(tbl.Name(), d.Column) 523 } 524 525 for _, col := range d.targetSchema { 526 if col.Default == nil { 527 continue 528 } 529 var err error 530 sql.Inspect(col.Default, func(expr sql.Expression) bool { 531 switch expr := expr.(type) { 532 case *expression.GetField: 533 if expr.Name() == d.Column { 534 err = sql.ErrDropColumnReferencedInDefault.New(d.Column, expr.Name()) 535 return false 536 } 537 } 538 return true 539 }) 540 if err != nil { 541 return err 542 } 543 } 544 545 if fkTable, ok := tbl.(sql.ForeignKeyTable); ok { 546 lowercaseColumn := strings.ToLower(d.Column) 547 fks, err := fkTable.GetDeclaredForeignKeys(ctx) 548 if err != nil { 549 return err 550 } 551 for _, fk := range fks { 552 for _, fkCol := range fk.Columns { 553 if lowercaseColumn == strings.ToLower(fkCol) { 554 return sql.ErrForeignKeyDropColumn.New(d.Column, fk.Name) 555 } 556 } 557 } 558 parentFks, err := fkTable.GetReferencedForeignKeys(ctx) 559 if err != nil { 560 return err 561 } 562 for _, parentFk := range parentFks { 563 for _, parentFkCol := range parentFk.Columns { 564 if lowercaseColumn == strings.ToLower(parentFkCol) { 565 return sql.ErrForeignKeyDropColumn.New(d.Column, parentFk.Name) 566 } 567 } 568 } 569 } 570 571 return nil 572 } 573 574 func (d *DropColumn) Schema() sql.Schema { 575 return types.OkResultSchema 576 } 577 578 func (d *DropColumn) Resolved() bool { 579 return d.Table.Resolved() && d.ddlNode.Resolved() && d.targetSchema.Resolved() 580 } 581 582 func (d *DropColumn) Children() []sql.Node { 583 return []sql.Node{d.Table} 584 } 585 586 func (d DropColumn) WithChildren(children ...sql.Node) (sql.Node, error) { 587 if len(children) != 1 { 588 return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) 589 } 590 d.Table = children[0] 591 return &d, nil 592 } 593 594 // CheckPrivileges implements the interface sql.Node. 595 func (d *DropColumn) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { 596 subject := sql.PrivilegeCheckSubject{ 597 Database: CheckPrivilegeNameForDatabase(d.Db), 598 Table: getTableName(d.Table), 599 } 600 return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) 601 } 602 603 // CollationCoercibility implements the interface sql.CollationCoercible. 604 func (*DropColumn) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 605 return sql.Collation_binary, 7 606 } 607 608 func (d DropColumn) WithTargetSchema(schema sql.Schema) (sql.Node, error) { 609 d.targetSchema = schema 610 return &d, nil 611 } 612 613 func (d *DropColumn) TargetSchema() sql.Schema { 614 return d.targetSchema 615 } 616 617 func (d *DropColumn) Expressions() []sql.Expression { 618 return transform.WrappedColumnDefaults(d.targetSchema) 619 } 620 621 func (d DropColumn) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { 622 if len(exprs) != len(d.targetSchema) { 623 return nil, sql.ErrInvalidChildrenNumber.New(d, len(exprs), len(d.targetSchema)) 624 } 625 626 sch, err := transform.SchemaWithDefaults(d.targetSchema, exprs) 627 if err != nil { 628 return nil, err 629 } 630 d.targetSchema = sch 631 632 return &d, nil 633 } 634 635 type RenameColumn struct { 636 ddlNode 637 Table sql.Node 638 ColumnName string 639 NewColumnName string 640 checks sql.CheckConstraints 641 targetSchema sql.Schema 642 } 643 644 var _ sql.Node = (*RenameColumn)(nil) 645 var _ sql.Databaser = (*RenameColumn)(nil) 646 var _ sql.SchemaTarget = (*RenameColumn)(nil) 647 var _ sql.CheckConstraintNode = (*RenameColumn)(nil) 648 var _ sql.CollationCoercible = (*RenameColumn)(nil) 649 650 func NewRenameColumnResolved(table *ResolvedTable, columnName string, newColumnName string) *RenameColumn { 651 return &RenameColumn{ 652 ddlNode: ddlNode{Db: table.SqlDatabase}, 653 Table: table, 654 ColumnName: columnName, 655 NewColumnName: newColumnName, 656 } 657 } 658 659 func NewRenameColumn(database sql.Database, table *UnresolvedTable, columnName string, newColumnName string) *RenameColumn { 660 return &RenameColumn{ 661 ddlNode: ddlNode{Db: database}, 662 Table: table, 663 ColumnName: columnName, 664 NewColumnName: newColumnName, 665 } 666 } 667 668 func (r *RenameColumn) Checks() sql.CheckConstraints { 669 return r.checks 670 } 671 672 func (r *RenameColumn) WithChecks(checks sql.CheckConstraints) sql.Node { 673 ret := *r 674 ret.checks = checks 675 return &ret 676 } 677 678 func (r *RenameColumn) WithDatabase(db sql.Database) (sql.Node, error) { 679 nr := *r 680 nr.Db = db 681 return &nr, nil 682 } 683 684 func (r RenameColumn) WithTargetSchema(schema sql.Schema) (sql.Node, error) { 685 r.targetSchema = schema 686 return &r, nil 687 } 688 689 func (r *RenameColumn) TargetSchema() sql.Schema { 690 return r.targetSchema 691 } 692 693 func (r *RenameColumn) String() string { 694 return fmt.Sprintf("rename column %s to %s", r.ColumnName, r.NewColumnName) 695 } 696 697 func (r *RenameColumn) IsReadOnly() bool { 698 return false 699 } 700 701 func (r *RenameColumn) DebugString() string { 702 pr := sql.NewTreePrinter() 703 pr.WriteNode("rename column %s to %s", r.ColumnName, r.NewColumnName) 704 705 var children []string 706 for _, col := range r.targetSchema { 707 children = append(children, sql.DebugString(col)) 708 } 709 710 pr.WriteChildren(children...) 711 return pr.String() 712 } 713 714 func (r *RenameColumn) Resolved() bool { 715 return r.Table.Resolved() && r.ddlNode.Resolved() && r.targetSchema.Resolved() 716 } 717 718 func (r *RenameColumn) Schema() sql.Schema { 719 return types.OkResultSchema 720 } 721 722 func (r *RenameColumn) Expressions() []sql.Expression { 723 return transform.WrappedColumnDefaults(r.targetSchema) 724 } 725 726 func (r RenameColumn) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { 727 if len(exprs) != len(r.targetSchema) { 728 return nil, sql.ErrInvalidChildrenNumber.New(r, len(exprs), len(r.targetSchema)) 729 } 730 731 sch, err := transform.SchemaWithDefaults(r.targetSchema, exprs) 732 if err != nil { 733 return nil, err 734 } 735 736 r.targetSchema = sch 737 return &r, nil 738 } 739 740 func (r *RenameColumn) Children() []sql.Node { 741 return []sql.Node{r.Table} 742 } 743 744 func (r RenameColumn) WithChildren(children ...sql.Node) (sql.Node, error) { 745 if len(children) != 1 { 746 return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 1) 747 } 748 r.Table = children[0] 749 return &r, nil 750 } 751 752 // CheckPrivileges implements the interface sql.Node. 753 func (r *RenameColumn) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { 754 subject := sql.PrivilegeCheckSubject{ 755 Database: CheckPrivilegeNameForDatabase(r.Db), 756 Table: getTableName(r.Table), 757 } 758 759 return opChecker.UserHasPrivileges(ctx, 760 sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) 761 } 762 763 // CollationCoercibility implements the interface sql.CollationCoercible. 764 func (*RenameColumn) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 765 return sql.Collation_binary, 7 766 } 767 768 type ModifyColumn struct { 769 ddlNode 770 Table sql.Node 771 columnName string 772 column *sql.Column 773 order *sql.ColumnOrder 774 targetSchema sql.Schema 775 } 776 777 var _ sql.Node = (*ModifyColumn)(nil) 778 var _ sql.Expressioner = (*ModifyColumn)(nil) 779 var _ sql.Databaser = (*ModifyColumn)(nil) 780 var _ sql.SchemaTarget = (*ModifyColumn)(nil) 781 var _ sql.CollationCoercible = (*ModifyColumn)(nil) 782 783 func NewModifyColumnResolved(table *ResolvedTable, columnName string, column sql.Column, order *sql.ColumnOrder) *ModifyColumn { 784 column.Source = table.Name() 785 return &ModifyColumn{ 786 ddlNode: ddlNode{Db: table.SqlDatabase}, 787 Table: table, 788 columnName: columnName, 789 column: &column, 790 order: order, 791 } 792 } 793 794 func NewModifyColumn(database sql.Database, table *UnresolvedTable, columnName string, column *sql.Column, order *sql.ColumnOrder) *ModifyColumn { 795 column.Source = table.name 796 return &ModifyColumn{ 797 ddlNode: ddlNode{Db: database}, 798 Table: table, 799 columnName: columnName, 800 column: column, 801 order: order, 802 } 803 } 804 805 func (m *ModifyColumn) WithDatabase(db sql.Database) (sql.Node, error) { 806 nm := *m 807 nm.Db = db 808 return &nm, nil 809 } 810 811 func (m *ModifyColumn) Column() string { 812 return m.columnName 813 } 814 815 func (m *ModifyColumn) NewColumn() *sql.Column { 816 return m.column 817 } 818 819 func (m *ModifyColumn) Order() *sql.ColumnOrder { 820 return m.order 821 } 822 823 // Schema implements the sql.Node interface. 824 func (m *ModifyColumn) Schema() sql.Schema { 825 return types.OkResultSchema 826 } 827 828 func (m *ModifyColumn) String() string { 829 return fmt.Sprintf("modify column %s", m.column.Name) 830 } 831 832 func (m *ModifyColumn) IsReadOnly() bool { 833 return false 834 } 835 836 func (m ModifyColumn) WithTargetSchema(schema sql.Schema) (sql.Node, error) { 837 m.targetSchema = schema 838 return &m, nil 839 } 840 841 func (m *ModifyColumn) TargetSchema() sql.Schema { 842 return m.targetSchema 843 } 844 845 func (m *ModifyColumn) Children() []sql.Node { 846 return []sql.Node{m.Table} 847 } 848 849 func (m ModifyColumn) WithChildren(children ...sql.Node) (sql.Node, error) { 850 if len(children) != 1 { 851 return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) 852 } 853 m.Table = children[0] 854 return &m, nil 855 } 856 857 // CheckPrivileges implements the interface sql.Node. 858 func (m *ModifyColumn) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { 859 subject := sql.PrivilegeCheckSubject{ 860 Database: CheckPrivilegeNameForDatabase(m.Db), 861 Table: getTableName(m.Table), 862 } 863 return opChecker.UserHasPrivileges(ctx, 864 sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) 865 } 866 867 // CollationCoercibility implements the interface sql.CollationCoercible. 868 func (*ModifyColumn) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 869 return sql.Collation_binary, 7 870 } 871 872 func (m *ModifyColumn) Expressions() []sql.Expression { 873 return append(transform.WrappedColumnDefaults(m.targetSchema), expression.WrapExpressions(m.column.Default)...) 874 } 875 876 func (m ModifyColumn) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { 877 if len(exprs) != 1+len(m.targetSchema) { 878 return nil, sql.ErrInvalidChildrenNumber.New(m, len(exprs), 1+len(m.targetSchema)) 879 } 880 881 sch, err := transform.SchemaWithDefaults(m.targetSchema, exprs[:len(m.targetSchema)]) 882 if err != nil { 883 return nil, err 884 } 885 m.targetSchema = sch 886 887 unwrappedColDefVal, ok := exprs[len(exprs)-1].(*expression.Wrapper).Unwrap().(*sql.ColumnDefaultValue) 888 if ok { 889 m.column.Default = unwrappedColDefVal 890 } else { // nil fails type check 891 m.column.Default = nil 892 } 893 return &m, nil 894 } 895 896 // Resolved implements the Resolvable interface. 897 func (m *ModifyColumn) Resolved() bool { 898 return m.Table.Resolved() && m.column.Default.Resolved() && m.ddlNode.Resolved() && m.targetSchema.Resolved() 899 } 900 901 func (m *ModifyColumn) ValidateDefaultPosition(tblSch sql.Schema) error { 902 colsBeforeThis := make(map[string]*sql.Column) 903 colsAfterThis := make(map[string]*sql.Column) // includes the modified column 904 if m.order == nil { 905 i := 0 906 for ; i < len(tblSch); i++ { 907 if tblSch[i].Name == m.column.Name { 908 colsAfterThis[m.column.Name] = m.column 909 break 910 } 911 colsBeforeThis[tblSch[i].Name] = tblSch[i] 912 } 913 for ; i < len(tblSch); i++ { 914 colsAfterThis[tblSch[i].Name] = tblSch[i] 915 } 916 } else if m.order.First { 917 for i := 0; i < len(tblSch); i++ { 918 colsAfterThis[tblSch[i].Name] = tblSch[i] 919 } 920 } else { 921 i := 1 922 for ; i < len(tblSch); i++ { 923 colsBeforeThis[tblSch[i].Name] = tblSch[i] 924 if tblSch[i-1].Name == m.order.AfterColumn { 925 break 926 } 927 } 928 for ; i < len(tblSch); i++ { 929 colsAfterThis[tblSch[i].Name] = tblSch[i] 930 } 931 delete(colsBeforeThis, m.column.Name) 932 colsAfterThis[m.column.Name] = m.column 933 } 934 935 err := inspectDefaultForInvalidColumns(m.column, colsAfterThis) 936 if err != nil { 937 return err 938 } 939 thisCol := map[string]*sql.Column{m.column.Name: m.column} 940 for _, colBefore := range colsBeforeThis { 941 err = inspectDefaultForInvalidColumns(colBefore, thisCol) 942 if err != nil { 943 return err 944 } 945 } 946 947 return nil 948 } 949 950 type AlterTableCollation struct { 951 ddlNode 952 Table sql.Node 953 Collation sql.CollationID 954 } 955 956 var _ sql.Node = (*AlterTableCollation)(nil) 957 var _ sql.Databaser = (*AlterTableCollation)(nil) 958 959 // NewAlterTableCollationResolved returns a new *AlterTableCollation 960 func NewAlterTableCollationResolved(table *ResolvedTable, collation sql.CollationID) *AlterTableCollation { 961 return &AlterTableCollation{ 962 ddlNode: ddlNode{Db: table.SqlDatabase}, 963 Table: table, 964 Collation: collation, 965 } 966 } 967 968 // NewAlterTableCollation returns a new *AlterTableCollation 969 func NewAlterTableCollation(database sql.Database, table *UnresolvedTable, collation sql.CollationID) *AlterTableCollation { 970 return &AlterTableCollation{ 971 ddlNode: ddlNode{Db: database}, 972 Table: table, 973 Collation: collation, 974 } 975 } 976 977 // WithDatabase implements the interface sql.Databaser. 978 func (atc *AlterTableCollation) WithDatabase(db sql.Database) (sql.Node, error) { 979 natc := *atc 980 natc.Db = db 981 return &natc, nil 982 } 983 984 func (atc *AlterTableCollation) IsReadOnly() bool { 985 return false 986 } 987 988 // String implements the interface sql.Node. 989 func (atc *AlterTableCollation) String() string { 990 return fmt.Sprintf("alter table %s collate %s", atc.Table.String(), atc.Collation.Name()) 991 } 992 993 // DebugString implements the interface sql.Node. 994 func (atc *AlterTableCollation) DebugString() string { 995 return atc.String() 996 } 997 998 // Resolved implements the interface sql.Node. 999 func (atc *AlterTableCollation) Resolved() bool { 1000 return atc.Table.Resolved() && atc.ddlNode.Resolved() 1001 } 1002 1003 // Schema implements the interface sql.Node. 1004 func (atc *AlterTableCollation) Schema() sql.Schema { 1005 return types.OkResultSchema 1006 } 1007 1008 // Children implements the interface sql.Node. 1009 func (atc *AlterTableCollation) Children() []sql.Node { 1010 return []sql.Node{atc.Table} 1011 } 1012 1013 // WithChildren implements the interface sql.Node. 1014 func (atc *AlterTableCollation) WithChildren(children ...sql.Node) (sql.Node, error) { 1015 if len(children) != 1 { 1016 return nil, sql.ErrInvalidChildrenNumber.New(atc, len(children), 1) 1017 } 1018 natc := *atc 1019 natc.Table = children[0] 1020 return &natc, nil 1021 } 1022 1023 // CheckPrivileges implements the interface sql.Node. 1024 func (atc *AlterTableCollation) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { 1025 subject := sql.PrivilegeCheckSubject{ 1026 Database: CheckPrivilegeNameForDatabase(atc.Db), 1027 Table: getTableName(atc.Table), 1028 } 1029 1030 return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) 1031 }