github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/ddl.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package planbuilder 16 17 import ( 18 "fmt" 19 "strconv" 20 "strings" 21 22 "github.com/dolthub/vitess/go/mysql" 23 ast "github.com/dolthub/vitess/go/vt/sqlparser" 24 25 "github.com/dolthub/go-mysql-server/sql" 26 "github.com/dolthub/go-mysql-server/sql/expression" 27 "github.com/dolthub/go-mysql-server/sql/expression/function" 28 "github.com/dolthub/go-mysql-server/sql/mysql_db" 29 "github.com/dolthub/go-mysql-server/sql/plan" 30 "github.com/dolthub/go-mysql-server/sql/types" 31 ) 32 33 func (b *Builder) resolveDb(name string) sql.Database { 34 if name == "" { 35 err := sql.ErrNoDatabaseSelected.New() 36 b.handleErr(err) 37 } 38 database, err := b.cat.Database(b.ctx, name) 39 if err != nil { 40 b.handleErr(err) 41 } 42 43 // todo show tables as of expects privileged 44 //if privilegedDatabase, ok := database.(mysql_db.PrivilegedDatabase); ok { 45 // database = privilegedDatabase.Unwrap() 46 //} 47 return database 48 } 49 50 // buildAlterTable converts AlterTable AST nodes. If there is a single clause in the statement, it is returned as 51 // the appropriate node type. Otherwise, a plan.Block is returned with children representing all the various clauses. 52 // Our validation rules for what counts as a legal set of alter clauses differs from mysql's here. MySQL seems to apply 53 // some form of precedence rules to the clauses in an ALTER TABLE so that e.g. DROP COLUMN always happens before other 54 // kinds of statements. So in MySQL, statements like `ALTER TABLE t ADD KEY (a), DROP COLUMN a` fails, whereas our 55 // analyzer happily produces a plan that adds an index and then drops that column. We do this in part for simplicity, 56 // and also because we construct more than one node per clause in some cases and really want them executed in a 57 // particular order in that case. 58 func (b *Builder) buildAlterTable(inScope *scope, query string, c *ast.AlterTable) (outScope *scope) { 59 b.multiDDL = true 60 defer func() { 61 b.multiDDL = false 62 }() 63 64 statements := make([]sql.Node, 0, len(c.Statements)) 65 for i := 0; i < len(c.Statements); i++ { 66 scopes := b.buildAlterTableClause(inScope, c.Statements[i]) 67 for _, scope := range scopes { 68 statements = append(statements, scope.node) 69 } 70 } 71 72 if len(statements) == 1 { 73 outScope = inScope.push() 74 outScope.node = statements[0] 75 return outScope 76 } 77 78 outScope = inScope.push() 79 outScope.node = plan.NewBlock(statements) 80 return 81 } 82 83 func (b *Builder) buildDDL(inScope *scope, query string, c *ast.DDL) (outScope *scope) { 84 outScope = inScope.push() 85 switch strings.ToLower(c.Action) { 86 case ast.CreateStr: 87 if c.TriggerSpec != nil { 88 return b.buildCreateTrigger(inScope, query, c) 89 } 90 if c.ProcedureSpec != nil { 91 return b.buildCreateProcedure(inScope, query, c) 92 } 93 if c.EventSpec != nil { 94 return b.buildCreateEvent(inScope, query, c) 95 } 96 if c.ViewSpec != nil { 97 return b.buildCreateView(inScope, query, c) 98 } 99 return b.buildCreateTable(inScope, c) 100 case ast.DropStr: 101 // get database 102 if c.TriggerSpec != nil { 103 dbName := c.TriggerSpec.TrigName.Qualifier.String() 104 if dbName == "" { 105 dbName = b.ctx.GetCurrentDatabase() 106 } 107 trigName := c.TriggerSpec.TrigName.Name.String() 108 outScope.node = plan.NewDropTrigger(b.resolveDb(dbName), trigName, c.IfExists) 109 return 110 } 111 if c.ProcedureSpec != nil { 112 dbName := c.ProcedureSpec.ProcName.Qualifier.String() 113 if dbName == "" { 114 dbName = b.ctx.GetCurrentDatabase() 115 } 116 procName := c.ProcedureSpec.ProcName.Name.String() 117 outScope.node = plan.NewDropProcedure(b.resolveDb(dbName), procName, c.IfExists) 118 return 119 } 120 if c.EventSpec != nil { 121 dbName := c.EventSpec.EventName.Qualifier.String() 122 if dbName == "" { 123 dbName = b.ctx.GetCurrentDatabase() 124 } 125 eventName := c.EventSpec.EventName.Name.String() 126 outScope.node = plan.NewDropEvent(b.resolveDb(dbName), eventName, c.IfExists) 127 return 128 } 129 if len(c.FromViews) != 0 { 130 plans := make([]sql.Node, len(c.FromViews)) 131 for i, v := range c.FromViews { 132 plans[i] = plan.NewSingleDropView(b.currentDb(), v.Name.String()) 133 } 134 outScope.node = plan.NewDropView(plans, c.IfExists) 135 return 136 } 137 return b.buildDropTable(inScope, c) 138 case ast.AlterStr: 139 if c.EventSpec != nil { 140 return b.buildAlterEvent(inScope, query, c) 141 } else if !c.User.IsEmpty() { 142 return b.buildAlterUser(inScope, query, c) 143 } 144 b.handleErr(sql.ErrUnsupportedFeature.New(ast.String(c))) 145 case ast.RenameStr: 146 return b.buildRenameTable(inScope, c) 147 case ast.TruncateStr: 148 return b.buildTruncateTable(inScope, c) 149 default: 150 b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(c))) 151 } 152 return 153 } 154 155 func (b *Builder) buildDropTable(inScope *scope, c *ast.DDL) (outScope *scope) { 156 outScope = inScope.push() 157 var dropTables []sql.Node 158 dbName := c.FromTables[0].Qualifier.String() 159 if dbName == "" { 160 dbName = b.currentDb().Name() 161 } 162 for _, t := range c.FromTables { 163 if t.Qualifier.String() != "" && t.Qualifier.String() != dbName { 164 err := sql.ErrUnsupportedFeature.New("dropping tables on multiple databases in the same statement") 165 b.handleErr(err) 166 } 167 tableName := strings.ToLower(t.Name.String()) 168 if c.IfExists { 169 _, _, err := b.cat.Table(b.ctx, dbName, tableName) 170 if sql.ErrTableNotFound.Is(err) && b.ctx != nil && b.ctx.Session != nil { 171 b.ctx.Session.Warn(&sql.Warning{ 172 Level: "Note", 173 Code: mysql.ERBadTable, 174 Message: fmt.Sprintf("Unknown table '%s'", tableName), 175 }) 176 continue 177 } else if err != nil { 178 b.handleErr(err) 179 } 180 } 181 182 tableScope, ok := b.buildResolvedTable(inScope, dbName, tableName, nil) 183 if ok { 184 dropTables = append(dropTables, tableScope.node) 185 } else if !c.IfExists { 186 err := sql.ErrTableNotFound.New(tableName) 187 b.handleErr(err) 188 } 189 } 190 191 outScope.node = plan.NewDropTable(dropTables, c.IfExists) 192 return 193 } 194 195 func (b *Builder) buildTruncateTable(inScope *scope, c *ast.DDL) (outScope *scope) { 196 outScope = inScope.push() 197 dbName := c.Table.Qualifier.String() 198 tabName := c.Table.Name.String() 199 tableScope, ok := b.buildResolvedTable(inScope, dbName, tabName, nil) 200 if !ok { 201 b.handleErr(sql.ErrTableNotFound.New(tabName)) 202 } 203 outScope.node = plan.NewTruncate( 204 c.Table.Qualifier.String(), 205 tableScope.node, 206 ) 207 return 208 } 209 210 func (b *Builder) buildCreateTable(inScope *scope, c *ast.DDL) (outScope *scope) { 211 outScope = inScope.push() 212 if c.OptLike != nil { 213 return b.buildCreateTableLike(inScope, c) 214 } 215 216 qualifier := c.Table.Qualifier.String() 217 if qualifier == "" { 218 qualifier = b.ctx.GetCurrentDatabase() 219 } 220 database := b.resolveDb(qualifier) 221 222 // In the case that no table spec is given but a SELECT Statement return the CREATE TABLE node. 223 // if the table spec != nil it will get parsed below. 224 if c.TableSpec == nil && c.OptSelect != nil { 225 tableSpec := &plan.TableSpec{} 226 227 selectScope := b.buildSelectStmt(inScope, c.OptSelect.Select) 228 229 outScope.node = plan.NewCreateTableSelect(database, c.Table.Name.String(), selectScope.node, tableSpec, plan.IfNotExistsOption(c.IfNotExists), plan.TempTableOption(c.Temporary)) 230 return outScope 231 } 232 233 idxDefs := b.buildIndexDefs(inScope, c.TableSpec) 234 235 schema, collation, comment := b.tableSpecToSchema(inScope, outScope, database, strings.ToLower(c.Table.Name.String()), c.TableSpec, false) 236 fkDefs, chDefs := b.buildConstraintsDefs(outScope, c.Table, c.TableSpec) 237 238 schema.Schema = assignColumnIndexesInSchema(schema.Schema) 239 chDefs = assignColumnIndexesInCheckDefs(chDefs, schema.Schema) 240 241 if privDb, ok := database.(mysql_db.PrivilegedDatabase); ok { 242 if sv, ok := privDb.Unwrap().(sql.SchemaValidator); ok { 243 if err := sv.ValidateSchema(schema.PhysicalSchema()); err != nil { 244 b.handleErr(err) 245 } 246 } 247 } else { 248 if sv, ok := database.(sql.SchemaValidator); ok { 249 if err := sv.ValidateSchema(schema.PhysicalSchema()); err != nil { 250 b.handleErr(err) 251 } 252 } 253 } 254 255 tableSpec := &plan.TableSpec{ 256 Schema: schema, 257 IdxDefs: idxDefs, 258 FkDefs: fkDefs, 259 ChDefs: chDefs, 260 Collation: collation, 261 Comment: comment, 262 } 263 264 if c.OptSelect != nil { 265 selectScope := b.buildSelectStmt(inScope, c.OptSelect.Select) 266 outScope.node = plan.NewCreateTableSelect(database, c.Table.Name.String(), selectScope.node, tableSpec, plan.IfNotExistsOption(c.IfNotExists), plan.TempTableOption(c.Temporary)) 267 } else { 268 outScope.node = plan.NewCreateTable( 269 database, c.Table.Name.String(), plan.IfNotExistsOption(c.IfNotExists), plan.TempTableOption(c.Temporary), tableSpec) 270 } 271 272 return 273 } 274 275 func assignColumnIndexesInCheckDefs(defs []*sql.CheckConstraint, schema sql.Schema) []*sql.CheckConstraint { 276 newDefs := make([]*sql.CheckConstraint, len(defs)) 277 for i, def := range defs { 278 newDefs[i] = def 279 newDefs[i].Expr = assignColumnIndexes(def.Expr, schema).(sql.Expression) 280 } 281 return newDefs 282 } 283 284 func assignColumnIndexesInSchema(schema sql.Schema) sql.Schema { 285 newSch := make(sql.Schema, len(schema)) 286 for i, col := range schema { 287 newSch[i] = col 288 if col.Default != nil { 289 newSch[i].Default = assignColumnIndexes(col.Default, schema).(*sql.ColumnDefaultValue) 290 } 291 if col.Generated != nil { 292 newSch[i].Generated = assignColumnIndexes(col.Generated, schema).(*sql.ColumnDefaultValue) 293 } 294 } 295 return newSch 296 } 297 298 func (b *Builder) buildCreateTableLike(inScope *scope, ct *ast.DDL) *scope { 299 tableName := ct.OptLike.LikeTable.Name.String() 300 likeDbName := ct.OptLike.LikeTable.Qualifier.String() 301 if likeDbName == "" { 302 likeDbName = b.ctx.GetCurrentDatabase() 303 } 304 outScope, ok := b.buildTablescan(inScope, likeDbName, tableName, nil) 305 if !ok { 306 b.handleErr(sql.ErrTableNotFound.New(tableName)) 307 } 308 likeTable, ok := outScope.node.(*plan.ResolvedTable) 309 if !ok { 310 err := fmt.Errorf("expected resolved table: %s", tableName) 311 b.handleErr(err) 312 } 313 314 newTableName := strings.ToLower(ct.Table.Name.String()) 315 outScope.setTableAlias(newTableName) 316 317 var idxDefs []*plan.IndexDefinition 318 if indexableTable, ok := likeTable.Table.(sql.IndexAddressableTable); ok { 319 indexes, err := indexableTable.GetIndexes(b.ctx) 320 if err != nil { 321 b.handleErr(err) 322 } 323 for _, index := range indexes { 324 if index.IsGenerated() { 325 continue 326 } 327 constraint := sql.IndexConstraint_None 328 if index.IsUnique() { 329 if index.ID() == "PRIMARY" { 330 constraint = sql.IndexConstraint_Primary 331 } else { 332 constraint = sql.IndexConstraint_Unique 333 } 334 } 335 336 columns := make([]sql.IndexColumn, len(index.Expressions())) 337 for i, col := range index.Expressions() { 338 //TODO: find a better way to get only the column name if the table is present 339 col = strings.TrimPrefix(col, indexableTable.Name()+".") 340 columns[i] = sql.IndexColumn{ 341 Name: col, 342 Length: 0, 343 } 344 } 345 idxDefs = append(idxDefs, &plan.IndexDefinition{ 346 IndexName: index.ID(), 347 Using: sql.IndexUsing_Default, 348 Constraint: constraint, 349 Columns: columns, 350 Comment: index.Comment(), 351 }) 352 } 353 } 354 origSch := likeTable.Schema() 355 newSch := make(sql.Schema, len(origSch)) 356 for i, col := range origSch { 357 tempCol := *col 358 tempCol.Source = newTableName 359 newSch[i] = &tempCol 360 } 361 362 var pkOrdinals []int 363 if pkTable, ok := likeTable.Table.(sql.PrimaryKeyTable); ok { 364 pkOrdinals = pkTable.PrimaryKeySchema().PkOrdinals 365 } 366 367 var checkDefs []*sql.CheckConstraint 368 if checksTable, ok := likeTable.Table.(sql.CheckTable); ok { 369 checks, err := checksTable.GetChecks(b.ctx) 370 if err != nil { 371 b.handleErr(err) 372 } 373 374 for _, check := range checks { 375 checkConstraint := b.buildCheckConstraint(outScope, &check) 376 if err != nil { 377 b.handleErr(err) 378 } 379 380 // Prevent a name collision between old and new checks. 381 // New check will be assigned a name during building. 382 checkConstraint.Name = "" 383 checkDefs = append(checkDefs, checkConstraint) 384 } 385 } 386 387 pkSchema := sql.NewPrimaryKeySchema(newSch, pkOrdinals...) 388 pkSchema.Schema = b.resolveSchemaDefaults(outScope, pkSchema.Schema) 389 390 tableSpec := &plan.TableSpec{ 391 Schema: pkSchema, 392 IdxDefs: idxDefs, 393 ChDefs: checkDefs, 394 Collation: likeTable.Collation(), 395 Comment: likeTable.Comment(), 396 } 397 398 qualifier := ct.Table.Qualifier.String() 399 if qualifier == "" { 400 qualifier = b.ctx.GetCurrentDatabase() 401 } 402 database := b.resolveDb(qualifier) 403 404 outScope.node = plan.NewCreateTable(database, newTableName, plan.IfNotExistsOption(ct.IfNotExists), plan.TempTableOption(ct.Temporary), tableSpec) 405 return outScope 406 } 407 408 func (b *Builder) buildRenameTable(inScope *scope, ddl *ast.DDL) (outScope *scope) { 409 outScope = inScope 410 if len(ddl.FromTables) != len(ddl.ToTables) { 411 panic("Expected from tables and to tables of equal length") 412 } 413 414 var fromTables, toTables []string 415 for _, table := range ddl.FromTables { 416 fromTables = append(fromTables, table.Name.String()) 417 } 418 for _, table := range ddl.ToTables { 419 toTables = append(toTables, table.Name.String()) 420 } 421 422 outScope.node = plan.NewRenameTable(b.currentDb(), fromTables, toTables, b.multiDDL) 423 return 424 } 425 426 func (b *Builder) isUniqueColumn(tableSpec *ast.TableSpec, columnName string) bool { 427 for _, column := range tableSpec.Columns { 428 if column.Name.String() == columnName { 429 return column.Type.KeyOpt == colKeyUnique || 430 column.Type.KeyOpt == colKeyUniqueKey 431 } 432 } 433 err := fmt.Errorf("unknown column name %s", columnName) 434 b.handleErr(err) 435 return false 436 437 } 438 439 func (b *Builder) buildAlterTableClause(inScope *scope, ddl *ast.DDL) []*scope { 440 outScopes := make([]*scope, 0, 1) 441 442 // RENAME a to b, c to d .. 443 if ddl.Action == ast.RenameStr { 444 outScopes = append(outScopes, b.buildRenameTable(inScope, ddl)) 445 } else { 446 dbName := ddl.Table.Qualifier.String() 447 tableName := ddl.Table.Name.String() 448 var ok bool 449 tableScope, ok := b.buildResolvedTable(inScope, dbName, tableName, nil) 450 if !ok { 451 b.handleErr(sql.ErrTableNotFound.New(tableName)) 452 } 453 rt, ok := tableScope.node.(*plan.ResolvedTable) 454 if !ok { 455 err := fmt.Errorf("expected resolved table: %s", tableName) 456 b.handleErr(err) 457 } 458 459 if ddl.ColumnAction != "" { 460 columnActionOutscope := b.buildAlterTableColumnAction(tableScope, ddl, rt) 461 outScopes = append(outScopes, columnActionOutscope) 462 463 if ddl.TableSpec != nil { 464 if len(ddl.TableSpec.Columns) != 1 { 465 err := sql.ErrUnsupportedFeature.New("unexpected number of columns in a single alter column clause") 466 b.handleErr(err) 467 } 468 469 column := ddl.TableSpec.Columns[0] 470 isUnique := b.isUniqueColumn(ddl.TableSpec, column.Name.String()) 471 if isUnique { 472 createIndex := plan.NewAlterCreateIndex( 473 rt.Database(), 474 rt, 475 column.Name.String(), 476 sql.IndexUsing_BTree, 477 sql.IndexConstraint_Unique, 478 []sql.IndexColumn{{Name: column.Name.String()}}, 479 "", 480 ) 481 482 createIndexScope := inScope.push() 483 createIndexScope.node = createIndex 484 outScopes = append(outScopes, createIndexScope) 485 } 486 } 487 } 488 489 if ddl.ConstraintAction != "" { 490 if len(ddl.TableSpec.Constraints) != 1 { 491 b.handleErr(sql.ErrUnsupportedFeature.New("unexpected number of constraints in a single alter constraint clause")) 492 } 493 outScopes = append(outScopes, b.buildAlterConstraint(tableScope, ddl, rt)) 494 } 495 496 if ddl.IndexSpec != nil { 497 outScopes = append(outScopes, b.buildAlterIndex(tableScope, ddl, rt)) 498 } 499 500 if ddl.AutoIncSpec != nil { 501 outScopes = append(outScopes, b.buildAlterAutoIncrement(tableScope, ddl, rt)) 502 } 503 504 if ddl.DefaultSpec != nil { 505 outScopes = append(outScopes, b.buildAlterDefault(tableScope, ddl, rt)) 506 } 507 508 if ddl.AlterCollationSpec != nil { 509 outScopes = append(outScopes, b.buildAlterCollationSpec(tableScope, ddl, rt)) 510 } 511 512 for _, s := range outScopes { 513 if ts, ok := s.node.(sql.SchemaTarget); ok { 514 s.node = b.modifySchemaTarget(s, ts, rt) 515 } 516 } 517 pkt, _ := rt.Table.(sql.PrimaryKeyTable) 518 if pkt != nil { 519 for _, s := range outScopes { 520 if ts, ok := s.node.(sql.PrimaryKeySchemaTarget); ok { 521 s.node = b.modifySchemaTarget(inScope, ts, rt) 522 ts.WithPrimaryKeySchema(pkt.PrimaryKeySchema()) 523 } 524 } 525 } 526 } 527 return outScopes 528 } 529 530 func (b *Builder) buildAlterTableColumnAction(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) { 531 outScope = inScope 532 switch strings.ToLower(ddl.ColumnAction) { 533 case ast.AddStr: 534 sch, _, _ := b.tableSpecToSchema(inScope, outScope, table.Database(), ddl.Table.Name.String(), ddl.TableSpec, true) 535 outScope.node = plan.NewAddColumnResolved(table, *sch.Schema[0], columnOrderToColumnOrder(ddl.ColumnOrder)) 536 case ast.DropStr: 537 drop := plan.NewDropColumnResolved(table, ddl.Column.String()) 538 checks := b.loadChecksFromTable(outScope, table.Table) 539 outScope.node = drop.WithChecks(checks) 540 case ast.RenameStr: 541 rename := plan.NewRenameColumnResolved(table, ddl.Column.String(), ddl.ToColumn.String()) 542 checks := b.loadChecksFromTable(outScope, table.Table) 543 outScope.node = rename.WithChecks(checks) 544 case ast.ModifyStr, ast.ChangeStr: 545 // modify adds a new column maybe with same name 546 // make new hierarchy so it resolves before old column 547 outScope = inScope.push() 548 sch, _, _ := b.tableSpecToSchema(inScope, outScope, table.Database(), ddl.Table.Name.String(), ddl.TableSpec, true) 549 modifyCol := plan.NewModifyColumnResolved(table, ddl.Column.String(), *sch.Schema[0], columnOrderToColumnOrder(ddl.ColumnOrder)) 550 outScope.node = modifyCol 551 default: 552 err := sql.ErrUnsupportedFeature.New(ast.String(ddl)) 553 b.handleErr(err) 554 } 555 556 return outScope 557 } 558 559 func (b *Builder) buildAlterConstraint(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) { 560 outScope = inScope 561 parsedConstraint := b.convertConstraintDefinition(inScope, ddl.TableSpec.Constraints[0]) 562 switch strings.ToLower(ddl.ConstraintAction) { 563 case ast.AddStr: 564 switch c := parsedConstraint.(type) { 565 case *sql.ForeignKeyConstraint: 566 c.Database = table.SqlDatabase.Name() 567 c.Table = table.Name() 568 alterFk := plan.NewAlterAddForeignKey(c) 569 alterFk.DbProvider = b.cat 570 outScope.node = alterFk 571 case *sql.CheckConstraint: 572 outScope.node = plan.NewAlterAddCheck(table, c) 573 default: 574 err := sql.ErrUnsupportedFeature.New(ast.String(ddl)) 575 b.handleErr(err) 576 } 577 case ast.DropStr: 578 switch c := parsedConstraint.(type) { 579 case *sql.ForeignKeyConstraint: 580 database := table.SqlDatabase.Name() 581 dropFk := plan.NewAlterDropForeignKey(database, table.Name(), c.Name) 582 dropFk.DbProvider = b.cat 583 outScope.node = dropFk 584 case *sql.CheckConstraint: 585 outScope.node = plan.NewAlterDropCheck(table, c.Name) 586 case namedConstraint: 587 outScope.node = &plan.DropConstraint{ 588 UnaryNode: plan.UnaryNode{Child: table}, 589 Name: c.name, 590 } 591 default: 592 err := sql.ErrUnsupportedFeature.New(ast.String(ddl)) 593 b.handleErr(err) 594 } 595 } 596 return 597 } 598 599 func (b *Builder) buildConstraintsDefs(inScope *scope, tname ast.TableName, spec *ast.TableSpec) (fks []*sql.ForeignKeyConstraint, checks []*sql.CheckConstraint) { 600 for _, unknownConstraint := range spec.Constraints { 601 parsedConstraint := b.convertConstraintDefinition(inScope, unknownConstraint) 602 switch constraint := parsedConstraint.(type) { 603 case *sql.ForeignKeyConstraint: 604 constraint.Database = tname.Qualifier.String() 605 constraint.Table = tname.Name.String() 606 if constraint.Database == "" { 607 constraint.Database = b.ctx.GetCurrentDatabase() 608 } 609 fks = append(fks, constraint) 610 case *sql.CheckConstraint: 611 checks = append(checks, constraint) 612 default: 613 err := sql.ErrUnknownConstraintDefinition.New(unknownConstraint.Name, unknownConstraint) 614 b.handleErr(err) 615 } 616 } 617 return 618 } 619 620 func columnOrderToColumnOrder(order *ast.ColumnOrder) *sql.ColumnOrder { 621 if order == nil { 622 return nil 623 } 624 if order.First { 625 return &sql.ColumnOrder{First: true} 626 } else { 627 return &sql.ColumnOrder{AfterColumn: order.AfterColumn.String()} 628 } 629 } 630 631 func (b *Builder) buildIndexDefs(inScope *scope, spec *ast.TableSpec) (idxDefs []*plan.IndexDefinition) { 632 for _, idxDef := range spec.Indexes { 633 constraint := sql.IndexConstraint_None 634 if idxDef.Info.Primary { 635 constraint = sql.IndexConstraint_Primary 636 } else if idxDef.Info.Unique { 637 constraint = sql.IndexConstraint_Unique 638 } else if idxDef.Info.Spatial { 639 constraint = sql.IndexConstraint_Spatial 640 } else if idxDef.Info.Fulltext { 641 constraint = sql.IndexConstraint_Fulltext 642 } 643 644 columns := b.gatherIndexColumns(idxDef.Columns) 645 646 var comment string 647 for _, option := range idxDef.Options { 648 if strings.ToLower(option.Name) == strings.ToLower(ast.KeywordString(ast.COMMENT_KEYWORD)) { 649 comment = string(option.Value.Val) 650 } 651 } 652 idxDefs = append(idxDefs, &plan.IndexDefinition{ 653 IndexName: idxDef.Info.Name.String(), 654 Using: sql.IndexUsing_Default, //TODO: add vitess support for USING 655 Constraint: constraint, 656 Columns: columns, 657 Comment: comment, 658 }) 659 } 660 661 for _, colDef := range spec.Columns { 662 if colDef.Type.KeyOpt == colKeyFulltextKey { 663 idxDefs = append(idxDefs, &plan.IndexDefinition{ 664 IndexName: "", 665 Using: sql.IndexUsing_Default, 666 Constraint: sql.IndexConstraint_Fulltext, 667 Comment: "", 668 Columns: []sql.IndexColumn{{ 669 Name: colDef.Name.String(), 670 Length: 0, 671 }}, 672 }) 673 } else if colDef.Type.KeyOpt == colKeyUnique || colDef.Type.KeyOpt == colKeyUniqueKey { 674 idxDefs = append(idxDefs, &plan.IndexDefinition{ 675 IndexName: "", 676 Using: sql.IndexUsing_Default, 677 Constraint: sql.IndexConstraint_Unique, 678 Comment: "", 679 Columns: []sql.IndexColumn{{ 680 Name: colDef.Name.String(), 681 Length: 0, 682 }}, 683 }) 684 } 685 } 686 return 687 } 688 689 type namedConstraint struct { 690 name string 691 } 692 693 func (b *Builder) convertConstraintDefinition(inScope *scope, cd *ast.ConstraintDefinition) interface{} { 694 if fkConstraint, ok := cd.Details.(*ast.ForeignKeyDefinition); ok { 695 columns := make([]string, len(fkConstraint.Source)) 696 for i, col := range fkConstraint.Source { 697 columns[i] = col.String() 698 } 699 refColumns := make([]string, len(fkConstraint.ReferencedColumns)) 700 for i, col := range fkConstraint.ReferencedColumns { 701 refColumns[i] = col.String() 702 } 703 refDatabase := fkConstraint.ReferencedTable.Qualifier.String() 704 if refDatabase == "" { 705 refDatabase = b.ctx.GetCurrentDatabase() 706 } 707 // The database and table are set in the calling function 708 return &sql.ForeignKeyConstraint{ 709 Name: cd.Name, 710 Columns: columns, 711 ParentDatabase: refDatabase, 712 ParentTable: fkConstraint.ReferencedTable.Name.String(), 713 ParentColumns: refColumns, 714 OnUpdate: b.buildReferentialAction(fkConstraint.OnUpdate), 715 OnDelete: b.buildReferentialAction(fkConstraint.OnDelete), 716 IsResolved: false, 717 } 718 } else if chConstraint, ok := cd.Details.(*ast.CheckConstraintDefinition); ok { 719 var c sql.Expression 720 if chConstraint.Expr != nil { 721 c = b.buildScalar(inScope, chConstraint.Expr) 722 } 723 724 return &sql.CheckConstraint{ 725 Name: cd.Name, 726 Expr: c, 727 Enforced: chConstraint.Enforced, 728 } 729 } else if len(cd.Name) > 0 && cd.Details == nil { 730 return namedConstraint{cd.Name} 731 } 732 err := sql.ErrUnknownConstraintDefinition.New(cd.Name, cd) 733 b.handleErr(err) 734 return nil 735 } 736 737 func (b *Builder) buildReferentialAction(action ast.ReferenceAction) sql.ForeignKeyReferentialAction { 738 switch action { 739 case ast.Restrict: 740 return sql.ForeignKeyReferentialAction_Restrict 741 case ast.Cascade: 742 return sql.ForeignKeyReferentialAction_Cascade 743 case ast.NoAction: 744 return sql.ForeignKeyReferentialAction_NoAction 745 case ast.SetNull: 746 return sql.ForeignKeyReferentialAction_SetNull 747 case ast.SetDefault: 748 return sql.ForeignKeyReferentialAction_SetDefault 749 default: 750 return sql.ForeignKeyReferentialAction_DefaultAction 751 } 752 } 753 754 func (b *Builder) buildAlterIndex(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) { 755 outScope = inScope 756 switch strings.ToLower(ddl.IndexSpec.Action) { 757 case ast.CreateStr: 758 var using sql.IndexUsing 759 switch ddl.IndexSpec.Using.Lowered() { 760 case "", "btree": 761 using = sql.IndexUsing_BTree 762 case "hash": 763 using = sql.IndexUsing_Hash 764 default: 765 return b.buildExternalCreateIndex(inScope, ddl) 766 } 767 768 var constraint sql.IndexConstraint 769 switch ddl.IndexSpec.Type { 770 case ast.UniqueStr: 771 constraint = sql.IndexConstraint_Unique 772 case ast.FulltextStr: 773 constraint = sql.IndexConstraint_Fulltext 774 case ast.SpatialStr: 775 constraint = sql.IndexConstraint_Spatial 776 case ast.PrimaryStr: 777 constraint = sql.IndexConstraint_Primary 778 default: 779 constraint = sql.IndexConstraint_None 780 } 781 782 columns := b.gatherIndexColumns(ddl.IndexSpec.Columns) 783 784 var comment string 785 for _, option := range ddl.IndexSpec.Options { 786 if strings.ToLower(option.Name) == strings.ToLower(ast.KeywordString(ast.COMMENT_KEYWORD)) { 787 comment = string(option.Value.Val) 788 } 789 } 790 791 if constraint == sql.IndexConstraint_Primary { 792 outScope.node = plan.NewAlterCreatePk(table.SqlDatabase, table, columns) 793 return 794 } 795 796 indexName := ddl.IndexSpec.ToName.String() 797 if strings.ToLower(indexName) == ast.PrimaryStr { 798 err := sql.ErrInvalidIndexName.New(indexName) 799 b.handleErr(err) 800 } 801 802 createIndex := plan.NewAlterCreateIndex(table.SqlDatabase, table, ddl.IndexSpec.ToName.String(), using, constraint, columns, comment) 803 outScope.node = b.modifySchemaTarget(inScope, createIndex, table) 804 return 805 case ast.DropStr: 806 if ddl.IndexSpec.Type == ast.PrimaryStr { 807 outScope.node = plan.NewAlterDropPk(table.SqlDatabase, table) 808 return 809 } 810 outScope.node = plan.NewAlterDropIndex(table.Database(), table, ddl.IndexSpec.ToName.String()) 811 return 812 case ast.RenameStr: 813 outScope.node = plan.NewAlterRenameIndex(table.Database(), table, ddl.IndexSpec.FromName.String(), ddl.IndexSpec.ToName.String()) 814 return 815 case "disable": 816 outScope.node = plan.NewAlterDisableEnableKeys(table.SqlDatabase, table, true) 817 return 818 case "enable": 819 outScope.node = plan.NewAlterDisableEnableKeys(table.SqlDatabase, table, false) 820 return 821 default: 822 err := sql.ErrUnsupportedFeature.New(ast.String(ddl)) 823 b.handleErr(err) 824 } 825 return 826 } 827 828 func (b *Builder) gatherIndexColumns(cols []*ast.IndexColumn) []sql.IndexColumn { 829 out := make([]sql.IndexColumn, len(cols)) 830 for i, col := range cols { 831 var length int64 832 var err error 833 if col.Length != nil && col.Length.Type == ast.IntVal { 834 length, err = strconv.ParseInt(string(col.Length.Val), 10, 64) 835 if err != nil { 836 b.handleErr(err) 837 } 838 if length < 1 { 839 err := sql.ErrKeyZero.New(col.Column) 840 b.handleErr(err) 841 } 842 } 843 out[i] = sql.IndexColumn{ 844 Name: col.Column.String(), 845 Length: length, 846 } 847 } 848 return out 849 } 850 851 func (b *Builder) buildAlterAutoIncrement(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) { 852 outScope = inScope 853 val, ok := ddl.AutoIncSpec.Value.(*ast.SQLVal) 854 if !ok { 855 err := sql.ErrInvalidSQLValType.New(ddl.AutoIncSpec.Value) 856 b.handleErr(err) 857 } 858 859 var autoVal uint64 860 if val.Type == ast.IntVal { 861 i, err := strconv.ParseUint(string(val.Val), 10, 64) 862 if err != nil { 863 b.handleErr(err) 864 } 865 autoVal = i 866 } else if val.Type == ast.FloatVal { 867 f, err := strconv.ParseFloat(string(val.Val), 10) 868 if err != nil { 869 b.handleErr(err) 870 } 871 autoVal = uint64(f) 872 } else { 873 err := sql.ErrInvalidSQLValType.New(ddl.AutoIncSpec.Value) 874 b.handleErr(err) 875 } 876 877 outScope.node = plan.NewAlterAutoIncrement(table.Database(), table, autoVal) 878 return 879 } 880 881 func (b *Builder) buildAlterDefault(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) { 882 outScope = inScope 883 switch strings.ToLower(ddl.DefaultSpec.Action) { 884 case ast.SetStr: 885 for _, c := range table.Schema() { 886 if strings.EqualFold(c.Name, ddl.DefaultSpec.Column.String()) { 887 defaultExpr := b.convertDefaultExpression(inScope, ddl.DefaultSpec.Value, c.Type, c.Nullable) 888 defSet := plan.NewAlterDefaultSet(table.Database(), table, ddl.DefaultSpec.Column.String(), defaultExpr) 889 outScope.node = b.modifySchemaTarget(inScope, defSet, table) 890 return 891 } 892 } 893 err := sql.ErrTableColumnNotFound.New(table.Name(), ddl.DefaultSpec.Column.String()) 894 b.handleErr(err) 895 return 896 case ast.DropStr: 897 outScope.node = plan.NewAlterDefaultDrop(table.Database(), table, ddl.DefaultSpec.Column.String()) 898 return 899 default: 900 err := sql.ErrUnsupportedFeature.New(ast.String(ddl)) 901 b.handleErr(err) 902 } 903 return 904 } 905 906 func (b *Builder) buildAlterCollationSpec(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) { 907 outScope = inScope 908 var charSetStr *string 909 var collationStr *string 910 if len(ddl.AlterCollationSpec.CharacterSet) > 0 { 911 charSetStr = &ddl.AlterCollationSpec.CharacterSet 912 } 913 if len(ddl.AlterCollationSpec.Collation) > 0 { 914 collationStr = &ddl.AlterCollationSpec.Collation 915 } 916 collation, err := sql.ParseCollation(charSetStr, collationStr, false) 917 if err != nil { 918 b.handleErr(err) 919 } 920 outScope.node = plan.NewAlterTableCollationResolved(table, collation) 921 return 922 } 923 924 func (b *Builder) buildDefaultExpression(inScope *scope, defaultExpr ast.Expr) *sql.ColumnDefaultValue { 925 if defaultExpr == nil { 926 return nil 927 } 928 parsedExpr := b.buildScalar(inScope, defaultExpr) 929 930 // Function expressions must be enclosed in parentheses (except for current_timestamp() and now()) 931 _, isParenthesized := defaultExpr.(*ast.ParenExpr) 932 isLiteral := !isParenthesized 933 934 // A literal will never have children, thus we can also check for that. 935 if unaryExpr, is := defaultExpr.(*ast.UnaryExpr); is { 936 if _, lit := unaryExpr.Expr.(*ast.SQLVal); lit { 937 isLiteral = true 938 } 939 } else if !isParenthesized { 940 if f, ok := parsedExpr.(*expression.UnresolvedFunction); ok { 941 // Datetime and Timestamp columns allow now and current_timestamp to not be enclosed in parens, 942 // but they still need to be treated as function expressions 943 switch strings.ToLower(f.Name()) { 944 case "now", "current_timestamp", "localtime", "localtimestamp": 945 isLiteral = false 946 default: 947 err := sql.ErrSyntaxError.New("column default function expressions must be enclosed in parentheses") 948 b.handleErr(err) 949 } 950 } 951 } 952 953 return ExpressionToColumnDefaultValue(parsedExpr, isLiteral, isParenthesized) 954 } 955 956 // ExpressionToColumnDefaultValue takes in an Expression and returns the equivalent ColumnDefaultValue if the expression 957 // is valid for a default value. If the expression represents a literal (and not an expression that returns a literal, so "5" 958 // rather than "(5)"), then the parameter "isLiteral" should be true. 959 func ExpressionToColumnDefaultValue(inputExpr sql.Expression, isLiteral, isParenthesized bool) *sql.ColumnDefaultValue { 960 return &sql.ColumnDefaultValue{ 961 Expr: inputExpr, 962 OutType: nil, 963 Literal: isLiteral, 964 ReturnNil: true, 965 Parenthesized: isParenthesized, 966 } 967 } 968 969 func (b *Builder) buildExternalCreateIndex(inScope *scope, ddl *ast.DDL) (outScope *scope) { 970 config := make(map[string]string) 971 for _, option := range ddl.IndexSpec.Options { 972 if option.Using != "" { 973 config[option.Name] = option.Using 974 } else { 975 config[option.Name] = string(option.Value.Val) 976 } 977 } 978 979 dbName := strings.ToLower(ddl.Table.Qualifier.String()) 980 tblName := strings.ToLower(ddl.Table.Name.String()) 981 var ok bool 982 outScope, ok = b.buildTablescan(inScope, dbName, tblName, nil) 983 if !ok { 984 b.handleErr(sql.ErrTableNotFound.New(tblName)) 985 } 986 table, ok := outScope.node.(*plan.ResolvedTable) 987 if !ok { 988 err := fmt.Errorf("expected resolved table: %s", tblName) 989 b.handleErr(err) 990 } 991 992 tableId := outScope.tables[tblName] 993 cols := make([]sql.Expression, len(ddl.IndexSpec.Columns)) 994 for i, col := range ddl.IndexSpec.Columns { 995 colName := strings.ToLower(col.Column.String()) 996 c, ok := inScope.resolveColumn(dbName, tblName, colName, true, false) 997 if !ok { 998 b.handleErr(sql.ErrColumnNotFound.New(colName)) 999 } 1000 cols[i] = expression.NewGetFieldWithTable(int(c.id), int(tableId), c.typ, c.db, c.table, c.col, c.nullable) 1001 } 1002 1003 createIndex := plan.NewCreateIndex( 1004 ddl.IndexSpec.ToName.String(), 1005 table, 1006 cols, 1007 ddl.IndexSpec.Using.Lowered(), 1008 config, 1009 ) 1010 createIndex.Catalog = b.cat 1011 outScope.node = createIndex 1012 return 1013 } 1014 1015 // validateOnUpdateExprs ensures that the Time functions used for OnUpdate for columns is correct 1016 func validateOnUpdateExprs(col *sql.Column) error { 1017 if col.OnUpdate == nil { 1018 return nil 1019 } 1020 if !(types.IsDatetimeType(col.Type) || types.IsTimestampType(col.Type)) { 1021 return sql.ErrInvalidOnUpdate.New(col.Name) 1022 } 1023 now, ok := col.OnUpdate.Expr.(*function.Now) 1024 if !ok { 1025 return nil 1026 } 1027 children := now.Children() 1028 if len(children) == 0 { 1029 return nil 1030 } 1031 lit, isLit := children[0].(*expression.Literal) 1032 if !isLit { 1033 return nil 1034 } 1035 val, err := lit.Eval(nil, nil) 1036 if err != nil { 1037 return err 1038 } 1039 prec, ok := types.CoalesceInt(val) 1040 if !ok { 1041 return sql.ErrInvalidOnUpdate.New(col.Name) 1042 } 1043 if prec != 0 { 1044 return sql.ErrInvalidOnUpdate.New(col.Name) 1045 } 1046 return nil 1047 } 1048 1049 // TableSpecToSchema creates a sql.Schema from a parsed TableSpec and returns the parsed primary key schema, collation ID, and table comment. 1050 func (b *Builder) tableSpecToSchema(inScope, outScope *scope, db sql.Database, tableName string, tableSpec *ast.TableSpec, forceInvalidCollation bool) (sql.PrimaryKeySchema, sql.CollationID, string) { 1051 // todo: somewhere downstream updates an ALTER MODIY column's type collation 1052 // to match the underlying. That only happens if the type stays unspecified. 1053 tableCollation := sql.Collation_Unspecified 1054 tableComment := "" 1055 if !forceInvalidCollation { 1056 tableCollation = sql.Collation_Default 1057 if cdb, _ := db.(sql.CollatedDatabase); cdb != nil { 1058 tableCollation = cdb.GetCollation(b.ctx) 1059 } 1060 if len(tableSpec.Options) > 0 { 1061 charsetSubmatches := tableCharsetOptionRegex.FindStringSubmatch(tableSpec.Options) 1062 collationSubmatches := tableCollationOptionRegex.FindStringSubmatch(tableSpec.Options) 1063 commentSubmatches := tableCommentOptionRegex.FindStringSubmatch(tableSpec.Options) 1064 if len(charsetSubmatches) == 5 && len(collationSubmatches) == 5 { 1065 var err error 1066 tableCollation, err = sql.ParseCollation(&charsetSubmatches[4], &collationSubmatches[4], false) 1067 if err != nil { 1068 return sql.PrimaryKeySchema{}, sql.Collation_Unspecified, "" 1069 } 1070 } else if len(charsetSubmatches) == 5 { 1071 charset, err := sql.ParseCharacterSet(charsetSubmatches[4]) 1072 if err != nil { 1073 return sql.PrimaryKeySchema{}, sql.Collation_Unspecified, "" 1074 } 1075 tableCollation = charset.DefaultCollation() 1076 } else if len(collationSubmatches) == 5 { 1077 var err error 1078 tableCollation, err = sql.ParseCollation(nil, &collationSubmatches[4], false) 1079 if err != nil { 1080 return sql.PrimaryKeySchema{}, sql.Collation_Unspecified, "" 1081 } 1082 } 1083 if len(commentSubmatches) == 5 { 1084 tableComment = commentSubmatches[4] 1085 } 1086 } 1087 } 1088 1089 tabId := outScope.addTable(tableName) 1090 1091 defaults := make([]ast.Expr, len(tableSpec.Columns)) 1092 generated := make([]ast.Expr, len(tableSpec.Columns)) 1093 updates := make([]ast.Expr, len(tableSpec.Columns)) 1094 var schema sql.Schema 1095 for i, cd := range tableSpec.Columns { 1096 if cd.Type.ResolvedType == nil { 1097 sqlType := cd.Type.SQLType() 1098 // Use the table's collation if no character or collation was specified for the table 1099 if len(cd.Type.Charset) == 0 && len(cd.Type.Collate) == 0 { 1100 if tableCollation != sql.Collation_Unspecified && !types.IsBinary(sqlType) { 1101 cd.Type.Collate = tableCollation.Name() 1102 } 1103 } 1104 } 1105 defaults[i] = cd.Type.Default 1106 generated[i] = cd.Type.GeneratedExpr 1107 updates[i] = cd.Type.OnUpdate 1108 1109 column := b.columnDefinitionToColumn(inScope, cd, tableSpec.Indexes) 1110 column.DatabaseSource = db.Name() 1111 1112 if column.PrimaryKey && bool(cd.Type.Null) { 1113 b.handleErr(ErrPrimaryKeyOnNullField.New()) 1114 } 1115 1116 schema = append(schema, column) 1117 outScope.newColumn(scopeColumn{ 1118 tableId: tabId, 1119 table: tableName, 1120 db: db.Name(), 1121 col: strings.ToLower(column.Name), 1122 typ: column.Type, 1123 nullable: column.Nullable, 1124 }) 1125 } 1126 1127 for i, def := range defaults { 1128 schema[i].Default = b.convertDefaultExpression(outScope, def, schema[i].Type, schema[i].Nullable) 1129 if def != nil && generated[i] != nil { 1130 b.handleErr(sql.ErrGeneratedColumnWithDefault.New()) 1131 return sql.PrimaryKeySchema{}, sql.Collation_Unspecified, "" 1132 } 1133 } 1134 1135 for i, gen := range generated { 1136 if gen != nil { 1137 virtual := !bool(tableSpec.Columns[i].Type.Stored) 1138 schema[i].Generated = b.convertDefaultExpression(outScope, gen, schema[i].Type, schema[i].Nullable) 1139 // generated expressions are always parenthesized, but we don't record this in the parser 1140 schema[i].Generated.Parenthesized = true 1141 schema[i].Generated.Literal = false 1142 schema[i].Virtual = virtual 1143 } 1144 } 1145 1146 for i, onUpdateExpr := range updates { 1147 schema[i].OnUpdate = b.convertDefaultExpression(outScope, onUpdateExpr, schema[i].Type, schema[i].Nullable) 1148 err := validateOnUpdateExprs(schema[i]) 1149 if err != nil { 1150 b.handleErr(err) 1151 } 1152 } 1153 1154 pkSch := sql.NewPrimaryKeySchema(schema, getPkOrdinals(tableSpec)...) 1155 return pkSch, tableCollation, tableComment 1156 } 1157 1158 // jsonTableSpecToSchemaHelper creates a sql.Schema from a parsed TableSpec 1159 func (b *Builder) jsonTableSpecToSchemaHelper(jsonTableSpec *ast.JSONTableSpec, sch sql.Schema) { 1160 for _, cd := range jsonTableSpec.Columns { 1161 if cd.Spec != nil { 1162 b.jsonTableSpecToSchemaHelper(cd.Spec, sch) 1163 continue 1164 } 1165 typ, err := types.ColumnTypeToType(&cd.Type) 1166 if err != nil { 1167 b.handleErr(err) 1168 } 1169 col := &sql.Column{ 1170 Type: typ, 1171 Name: cd.Name.String(), 1172 AutoIncrement: bool(cd.Type.Autoincrement), 1173 } 1174 sch = append(sch, col) 1175 continue 1176 } 1177 } 1178 1179 // jsonTableSpecToSchema creates a sql.Schema from a parsed TableSpec 1180 func (b *Builder) jsonTableSpecToSchema(tableSpec *ast.JSONTableSpec) sql.Schema { 1181 var sch sql.Schema 1182 b.jsonTableSpecToSchemaHelper(tableSpec, sch) 1183 return sch 1184 } 1185 1186 // These constants aren't exported from vitess for some reason. This could be removed if we changed this. 1187 const ( 1188 colKeyNone ast.ColumnKeyOption = iota 1189 colKeyPrimary 1190 colKeySpatialKey 1191 colKeyUnique 1192 colKeyUniqueKey 1193 colKey 1194 colKeyFulltextKey 1195 ) 1196 1197 func getPkOrdinals(ts *ast.TableSpec) []int { 1198 for _, idxDef := range ts.Indexes { 1199 if idxDef.Info.Primary { 1200 1201 pkOrdinals := make([]int, 0) 1202 colIdx := make(map[string]int) 1203 for i := 0; i < len(ts.Columns); i++ { 1204 colIdx[ts.Columns[i].Name.Lowered()] = i 1205 } 1206 1207 for _, i := range idxDef.Columns { 1208 pkOrdinals = append(pkOrdinals, colIdx[i.Column.Lowered()]) 1209 } 1210 1211 return pkOrdinals 1212 } 1213 } 1214 1215 // no primary key expression, check for inline PK column 1216 for i, col := range ts.Columns { 1217 if col.Type.KeyOpt == colKeyPrimary { 1218 return []int{i} 1219 } 1220 } 1221 1222 return []int{} 1223 } 1224 1225 // columnDefinitionToColumn returns the sql.Column for the column definition given, as part of a create table 1226 // statement. Defaults and generated expressions must be handled separately. 1227 func (b *Builder) columnDefinitionToColumn(inScope *scope, cd *ast.ColumnDefinition, indexes []*ast.IndexDefinition) *sql.Column { 1228 internalTyp, err := types.ColumnTypeToType(&cd.Type) 1229 if err != nil { 1230 b.handleErr(err) 1231 } 1232 1233 // Primary key info can either be specified in the column's type info (for in-line declarations), or in a slice of 1234 // indexes attached to the table def. We have to check both places to find if a column is part of the primary key 1235 isPkey := cd.Type.KeyOpt == colKeyPrimary 1236 1237 if !isPkey { 1238 OuterLoop: 1239 for _, index := range indexes { 1240 if index.Info.Primary { 1241 for _, indexCol := range index.Columns { 1242 if indexCol.Column.Equal(cd.Name) { 1243 isPkey = true 1244 break OuterLoop 1245 } 1246 } 1247 } 1248 } 1249 } 1250 1251 var comment string 1252 if cd.Type.Comment != nil && cd.Type.Comment.Type == ast.StrVal { 1253 comment = string(cd.Type.Comment.Val) 1254 } 1255 1256 nullable := !isPkey && !bool(cd.Type.NotNull) 1257 extra := "" 1258 1259 if cd.Type.Autoincrement { 1260 extra = "auto_increment" 1261 } 1262 1263 if cd.Type.SRID != nil { 1264 sridVal, err := strconv.ParseInt(string(cd.Type.SRID.Val), 10, 32) 1265 if err != nil { 1266 b.handleErr(err) 1267 } 1268 1269 if err = types.ValidateSRID(int(sridVal), ""); err != nil { 1270 b.handleErr(err) 1271 } 1272 if s, ok := internalTyp.(sql.SpatialColumnType); ok { 1273 internalTyp = s.SetSRID(uint32(sridVal)) 1274 } else { 1275 b.handleErr(sql.ErrInvalidType.New(fmt.Sprintf("cannot define SRID for %s", internalTyp))) 1276 } 1277 } 1278 1279 return &sql.Column{ 1280 Name: cd.Name.String(), 1281 Type: internalTyp, 1282 AutoIncrement: bool(cd.Type.Autoincrement), 1283 Nullable: nullable, 1284 PrimaryKey: isPkey, 1285 Comment: comment, 1286 Extra: extra, 1287 } 1288 } 1289 1290 func (b *Builder) modifySchemaTarget(inScope *scope, n sql.SchemaTarget, rt *plan.ResolvedTable) sql.Node { 1291 targSchema := b.resolveSchemaDefaults(inScope, rt.Schema()) 1292 ret, err := n.WithTargetSchema(targSchema) 1293 if err != nil { 1294 b.handleErr(err) 1295 } 1296 return ret 1297 } 1298 1299 func (b *Builder) resolveSchemaDefaults(inScope *scope, schema sql.Schema) sql.Schema { 1300 if len(schema) == 0 { 1301 return nil 1302 } 1303 if len(inScope.cols) < len(schema) { 1304 // alter statements only add definitions for modified columns 1305 // backfill rest of columns 1306 resolveScope := inScope.replace() 1307 for _, col := range schema { 1308 resolveScope.newColumn(scopeColumn{ 1309 db: col.DatabaseSource, 1310 table: strings.ToLower(col.Source), 1311 col: strings.ToLower(col.Name), 1312 typ: col.Type, 1313 nullable: col.Nullable, 1314 }) 1315 } 1316 inScope = resolveScope 1317 } 1318 1319 newSch := schema.Copy() 1320 for _, part := range partitionTableColumns(newSch) { 1321 start := part[0] 1322 end := part[1] 1323 subScope := inScope.replace() 1324 for i := start; i < end; i++ { 1325 subScope.addColumn(inScope.cols[i]) 1326 } 1327 for _, col := range newSch[start:end] { 1328 col.Default = b.resolveColumnDefaultExpression(subScope, col, col.Default) 1329 col.Generated = b.resolveColumnDefaultExpression(subScope, col, col.Generated) 1330 col.OnUpdate = b.resolveColumnDefaultExpression(subScope, col, col.OnUpdate) 1331 } 1332 } 1333 return newSch 1334 } 1335 1336 // partitionTableColumns splits a sql.Schema into a list 1337 // of [2]int{start,end} ranges that each partition the tables 1338 // included in the schema. 1339 func partitionTableColumns(sch sql.Schema) [][2]int { 1340 var ret [][2]int 1341 var i int = 1 1342 var prevI int = 0 1343 for i < len(sch) { 1344 if strings.EqualFold(sch[i-1].Source, sch[i].Source) && 1345 strings.EqualFold(sch[i-1].DatabaseSource, sch[i].DatabaseSource) { 1346 i++ 1347 continue 1348 } 1349 ret = append(ret, [2]int{prevI, i}) 1350 prevI = i 1351 i++ 1352 } 1353 ret = append(ret, [2]int{prevI, i}) 1354 return ret 1355 } 1356 1357 func (b *Builder) resolveColumnDefaultExpression(inScope *scope, columnDef *sql.Column, colDefault *sql.ColumnDefaultValue) *sql.ColumnDefaultValue { 1358 if colDefault == nil || colDefault.Expr == nil { 1359 return colDefault 1360 } 1361 1362 def, ok := colDefault.Expr.(*sql.UnresolvedColumnDefault) 1363 if !ok { 1364 // no resolution work to be done, return the original value 1365 return colDefault 1366 } 1367 1368 // Empty string is a special case, it means the default value is the empty string 1369 // TODO: why isn't this serialized as '' 1370 if def.String() == "" { 1371 return b.convertDefaultExpression(inScope, &ast.SQLVal{Val: []byte{}, Type: ast.StrVal}, columnDef.Type, columnDef.Nullable) 1372 } 1373 1374 parsed, err := ast.Parse(fmt.Sprintf("SELECT %s", def)) 1375 if err != nil { 1376 err := fmt.Errorf("%w: %s", sql.ErrInvalidColumnDefaultValue.New(def), err) 1377 b.handleErr(err) 1378 } 1379 1380 selectStmt, ok := parsed.(*ast.Select) 1381 if !ok || len(selectStmt.SelectExprs) != 1 { 1382 err := sql.ErrInvalidColumnDefaultValue.New(def) 1383 b.handleErr(err) 1384 } 1385 1386 expr := selectStmt.SelectExprs[0] 1387 ae, ok := expr.(*ast.AliasedExpr) 1388 if !ok { 1389 err := sql.ErrInvalidColumnDefaultValue.New(def) 1390 b.handleErr(err) 1391 } 1392 1393 return b.convertDefaultExpression(inScope, ae.Expr, columnDef.Type, columnDef.Nullable) 1394 } 1395 1396 func (b *Builder) convertDefaultExpression(inScope *scope, defaultExpr ast.Expr, typ sql.Type, nullable bool) *sql.ColumnDefaultValue { 1397 if defaultExpr == nil { 1398 return nil 1399 } 1400 resExpr := b.buildScalar(inScope, defaultExpr) 1401 1402 // Function expressions must be enclosed in parentheses (except for current_timestamp() and now()) 1403 _, isParenthesized := defaultExpr.(*ast.ParenExpr) 1404 isLiteral := !isParenthesized 1405 1406 // A literal will never have children, thus we can also check for that. 1407 if unaryExpr, is := defaultExpr.(*ast.UnaryExpr); is { 1408 if _, lit := unaryExpr.Expr.(*ast.SQLVal); lit { 1409 isLiteral = true 1410 } 1411 } else if !isParenthesized { 1412 if _, ok := resExpr.(sql.FunctionExpression); ok { 1413 switch resExpr.(type) { 1414 case *function.Now: 1415 // Datetime and Timestamp columns allow now and current_timestamp to not be enclosed in parens, 1416 // but they still need to be treated as function expressions 1417 isLiteral = false 1418 default: 1419 // All other functions must *always* be enclosed in parens 1420 err := sql.ErrSyntaxError.New("column default function expressions must be enclosed in parentheses") 1421 b.handleErr(err) 1422 } 1423 } 1424 } 1425 1426 // TODO: fix the vitess parser so that it parses negative numbers as numbers and not negation of an expression 1427 if unaryMinusExpr, ok := resExpr.(*expression.UnaryMinus); ok { 1428 if literalExpr, ok := unaryMinusExpr.Child.(*expression.Literal); ok { 1429 switch val := literalExpr.Value().(type) { 1430 case float32: 1431 resExpr = expression.NewLiteral(-val, types.Float32) 1432 isLiteral = true 1433 case float64: 1434 resExpr = expression.NewLiteral(-val, types.Float64) 1435 isLiteral = true 1436 } 1437 } 1438 } 1439 1440 return &sql.ColumnDefaultValue{ 1441 Expr: resExpr, 1442 OutType: typ, 1443 Literal: isLiteral, 1444 ReturnNil: nullable, 1445 Parenthesized: isParenthesized, 1446 } 1447 } 1448 1449 func (b *Builder) buildDBDDL(inScope *scope, c *ast.DBDDL) (outScope *scope) { 1450 outScope = inScope.push() 1451 switch strings.ToLower(c.Action) { 1452 case ast.CreateStr: 1453 var charsetStr *string 1454 var collationStr *string 1455 for _, cc := range c.CharsetCollate { 1456 ccType := strings.ToLower(cc.Type) 1457 if ccType == "character set" { 1458 val := cc.Value 1459 charsetStr = &val 1460 } else if ccType == "collate" { 1461 val := cc.Value 1462 collationStr = &val 1463 } else if b.ctx != nil && b.ctx.Session != nil { 1464 b.ctx.Session.Warn(&sql.Warning{ 1465 Level: "Warning", 1466 Code: mysql.ERNotSupportedYet, 1467 Message: "Setting CHARACTER SET, COLLATION and ENCRYPTION are not supported yet", 1468 }) 1469 } 1470 } 1471 collation, err := sql.ParseCollation(charsetStr, collationStr, false) 1472 if err != nil { 1473 b.handleErr(err) 1474 } 1475 createDb := plan.NewCreateDatabase(c.DBName, c.IfNotExists, collation) 1476 createDb.Catalog = b.cat 1477 outScope.node = createDb 1478 case ast.DropStr: 1479 dropDb := plan.NewDropDatabase(c.DBName, c.IfExists) 1480 dropDb.Catalog = b.cat 1481 outScope.node = dropDb 1482 case ast.AlterStr: 1483 if len(c.CharsetCollate) == 0 { 1484 if len(c.DBName) > 0 { 1485 err := sql.ErrSyntaxError.New(fmt.Sprintf("alter database %s", c.DBName)) 1486 b.handleErr(err) 1487 } else { 1488 err := sql.ErrSyntaxError.New("alter database") 1489 b.handleErr(err) 1490 } 1491 } 1492 1493 var charsetStr *string 1494 var collationStr *string 1495 for _, cc := range c.CharsetCollate { 1496 ccType := strings.ToLower(cc.Type) 1497 if ccType == "character set" { 1498 val := cc.Value 1499 charsetStr = &val 1500 } else if ccType == "collate" { 1501 val := cc.Value 1502 collationStr = &val 1503 } 1504 } 1505 collation, err := sql.ParseCollation(charsetStr, collationStr, false) 1506 if err != nil { 1507 b.handleErr(err) 1508 } 1509 alterDb := plan.NewAlterDatabase(c.DBName, collation) 1510 alterDb.Catalog = b.cat 1511 outScope.node = alterDb 1512 default: 1513 err := sql.ErrUnsupportedSyntax.New(ast.String(c)) 1514 b.handleErr(err) 1515 } 1516 return outScope 1517 } 1518 1519 // ExtendedTypeTag is primarily used by ParseColumnTypeString when parsing strings representing extended types 1520 const ExtendedTypeTag = "extended_" 1521 1522 func ParseColumnTypeString(columnType string) (sql.Type, error) { 1523 if strings.HasPrefix(columnType, ExtendedTypeTag) { 1524 columnType = columnType[len(ExtendedTypeTag):] 1525 // If the pipe character "|" is present, then we ignore all information after it (including the pipe), as it 1526 // represents a comment 1527 if pipeIdx := strings.Index(columnType, "|"); pipeIdx != -1 { 1528 columnType = columnType[:pipeIdx] 1529 } 1530 c, err := types.DeserializeTypeFromString(columnType) 1531 if err != nil { 1532 return nil, err 1533 } 1534 return c, nil 1535 } 1536 parsed, err := ast.Parse(fmt.Sprintf("create table t(a %s)", columnType)) 1537 if err != nil { 1538 return nil, err 1539 } 1540 ddl, ok := parsed.(*ast.DDL) 1541 if !ok { 1542 return nil, fmt.Errorf("failed to parse type info for column: %s", columnType) 1543 } 1544 parsedTyp := ddl.TableSpec.Columns[0].Type 1545 typ, err := types.ColumnTypeToType(&parsedTyp) 1546 if err != nil { 1547 return nil, err 1548 } 1549 if parsedTyp.SRID != nil { 1550 sridVal, err := strconv.ParseInt(string(parsedTyp.SRID.Val), 10, 32) 1551 if err != nil { 1552 return nil, err 1553 } 1554 1555 if err = types.ValidateSRID(int(sridVal), ""); err != nil { 1556 return nil, err 1557 } 1558 if s, ok := typ.(sql.SpatialColumnType); ok { 1559 typ = s.SetSRID(uint32(sridVal)) 1560 } else { 1561 return nil, sql.ErrInvalidType.New(fmt.Sprintf("cannot define SRID for %s", typ)) 1562 } 1563 } 1564 return typ, nil 1565 } 1566 1567 var _ sql.Database = dummyDb{} 1568 1569 type dummyDb struct { 1570 name string 1571 } 1572 1573 func (d dummyDb) Name() string { return d.name } 1574 func (d dummyDb) Tables() map[string]sql.Table { return nil } 1575 func (d dummyDb) GetTableInsensitive(ctx *sql.Context, tblName string) (sql.Table, bool, error) { 1576 return nil, false, nil 1577 } 1578 func (d dummyDb) GetTableNames(ctx *sql.Context) ([]string, error) { 1579 return nil, nil 1580 }