github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/ddl.go

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package planbuilder
    16  
    17  import (
    18  	"fmt"
    19  	"strconv"
    20  	"strings"
    21  
    22  	"github.com/dolthub/vitess/go/mysql"
    23  	ast "github.com/dolthub/vitess/go/vt/sqlparser"
    24  
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  	"github.com/dolthub/go-mysql-server/sql/expression"
    27  	"github.com/dolthub/go-mysql-server/sql/expression/function"
    28  	"github.com/dolthub/go-mysql-server/sql/mysql_db"
    29  	"github.com/dolthub/go-mysql-server/sql/plan"
    30  	"github.com/dolthub/go-mysql-server/sql/types"
    31  )
    32  
    33  func (b *Builder) resolveDb(name string) sql.Database {
    34  	if name == "" {
    35  		err := sql.ErrNoDatabaseSelected.New()
    36  		b.handleErr(err)
    37  	}
    38  	database, err := b.cat.Database(b.ctx, name)
    39  	if err != nil {
    40  		b.handleErr(err)
    41  	}
    42  
    43  	// todo: SHOW TABLES AS OF expects the privileged database
    44  	//if privilegedDatabase, ok := database.(mysql_db.PrivilegedDatabase); ok {
    45  	//	database = privilegedDatabase.Unwrap()
    46  	//}
    47  	return database
    48  }
    49  
    50  // buildAlterTable converts AlterTable AST nodes. If there is a single clause in the statement, it is returned as
    51  // the appropriate node type. Otherwise, a plan.Block is returned with children representing all the various clauses.
    52  // Our validation rules for what counts as a legal set of alter clauses differ from MySQL's here. MySQL seems to apply
    53  // some form of precedence rules to the clauses in an ALTER TABLE so that e.g. DROP COLUMN always happens before other
    54  // kinds of clauses. So in MySQL, a statement like `ALTER TABLE t ADD KEY (a), DROP COLUMN a` fails, whereas our
    55  // analyzer happily produces a plan that adds an index and then drops that column. We do this in part for simplicity,
    56  // and also because in some cases we construct more than one node per clause and need them executed in a
    57  // particular order.
    58  func (b *Builder) buildAlterTable(inScope *scope, query string, c *ast.AlterTable) (outScope *scope) {
    59  	b.multiDDL = true
    60  	defer func() {
    61  		b.multiDDL = false
    62  	}()
    63  
    64  	statements := make([]sql.Node, 0, len(c.Statements))
    65  	for i := 0; i < len(c.Statements); i++ {
    66  		scopes := b.buildAlterTableClause(inScope, c.Statements[i])
    67  		for _, scope := range scopes {
    68  			statements = append(statements, scope.node)
    69  		}
    70  	}
    71  
    72  	if len(statements) == 1 {
    73  		outScope = inScope.push()
    74  		outScope.node = statements[0]
    75  		return outScope
    76  	}
    77  
    78  	outScope = inScope.push()
    79  	outScope.node = plan.NewBlock(statements)
    80  	return
    81  }
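
        // Editor's sketch (not part of the original source): a multi-clause statement such as
        //
        //	ALTER TABLE t ADD COLUMN b int, DROP COLUMN a;
        //
        // produces one node per clause (here an add-column node followed by a drop-column node)
        // wrapped in a plan.Block, whereas a single-clause ALTER TABLE is returned as that clause's
        // node directly.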
    82  
    83  func (b *Builder) buildDDL(inScope *scope, query string, c *ast.DDL) (outScope *scope) {
    84  	outScope = inScope.push()
    85  	switch strings.ToLower(c.Action) {
    86  	case ast.CreateStr:
    87  		if c.TriggerSpec != nil {
    88  			return b.buildCreateTrigger(inScope, query, c)
    89  		}
    90  		if c.ProcedureSpec != nil {
    91  			return b.buildCreateProcedure(inScope, query, c)
    92  		}
    93  		if c.EventSpec != nil {
    94  			return b.buildCreateEvent(inScope, query, c)
    95  		}
    96  		if c.ViewSpec != nil {
    97  			return b.buildCreateView(inScope, query, c)
    98  		}
    99  		return b.buildCreateTable(inScope, c)
   100  	case ast.DropStr:
   101  		// get database
   102  		if c.TriggerSpec != nil {
   103  			dbName := c.TriggerSpec.TrigName.Qualifier.String()
   104  			if dbName == "" {
   105  				dbName = b.ctx.GetCurrentDatabase()
   106  			}
   107  			trigName := c.TriggerSpec.TrigName.Name.String()
   108  			outScope.node = plan.NewDropTrigger(b.resolveDb(dbName), trigName, c.IfExists)
   109  			return
   110  		}
   111  		if c.ProcedureSpec != nil {
   112  			dbName := c.ProcedureSpec.ProcName.Qualifier.String()
   113  			if dbName == "" {
   114  				dbName = b.ctx.GetCurrentDatabase()
   115  			}
   116  			procName := c.ProcedureSpec.ProcName.Name.String()
   117  			outScope.node = plan.NewDropProcedure(b.resolveDb(dbName), procName, c.IfExists)
   118  			return
   119  		}
   120  		if c.EventSpec != nil {
   121  			dbName := c.EventSpec.EventName.Qualifier.String()
   122  			if dbName == "" {
   123  				dbName = b.ctx.GetCurrentDatabase()
   124  			}
   125  			eventName := c.EventSpec.EventName.Name.String()
   126  			outScope.node = plan.NewDropEvent(b.resolveDb(dbName), eventName, c.IfExists)
   127  			return
   128  		}
   129  		if len(c.FromViews) != 0 {
   130  			plans := make([]sql.Node, len(c.FromViews))
   131  			for i, v := range c.FromViews {
   132  				plans[i] = plan.NewSingleDropView(b.currentDb(), v.Name.String())
   133  			}
   134  			outScope.node = plan.NewDropView(plans, c.IfExists)
   135  			return
   136  		}
   137  		return b.buildDropTable(inScope, c)
   138  	case ast.AlterStr:
   139  		if c.EventSpec != nil {
   140  			return b.buildAlterEvent(inScope, query, c)
   141  		} else if !c.User.IsEmpty() {
   142  			return b.buildAlterUser(inScope, query, c)
   143  		}
   144  		b.handleErr(sql.ErrUnsupportedFeature.New(ast.String(c)))
   145  	case ast.RenameStr:
   146  		return b.buildRenameTable(inScope, c)
   147  	case ast.TruncateStr:
   148  		return b.buildTruncateTable(inScope, c)
   149  	default:
   150  		b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(c)))
   151  	}
   152  	return
   153  }
   154  
   155  func (b *Builder) buildDropTable(inScope *scope, c *ast.DDL) (outScope *scope) {
   156  	outScope = inScope.push()
   157  	var dropTables []sql.Node
   158  	dbName := c.FromTables[0].Qualifier.String()
   159  	if dbName == "" {
   160  		dbName = b.currentDb().Name()
   161  	}
   162  	for _, t := range c.FromTables {
   163  		if t.Qualifier.String() != "" && t.Qualifier.String() != dbName {
   164  			err := sql.ErrUnsupportedFeature.New("dropping tables on multiple databases in the same statement")
   165  			b.handleErr(err)
   166  		}
   167  		tableName := strings.ToLower(t.Name.String())
   168  		if c.IfExists {
   169  			_, _, err := b.cat.Table(b.ctx, dbName, tableName)
   170  			if sql.ErrTableNotFound.Is(err) && b.ctx != nil && b.ctx.Session != nil {
   171  				b.ctx.Session.Warn(&sql.Warning{
   172  					Level:   "Note",
   173  					Code:    mysql.ERBadTable,
   174  					Message: fmt.Sprintf("Unknown table '%s'", tableName),
   175  				})
   176  				continue
   177  			} else if err != nil {
   178  				b.handleErr(err)
   179  			}
   180  		}
   181  
   182  		tableScope, ok := b.buildResolvedTable(inScope, dbName, tableName, nil)
   183  		if ok {
   184  			dropTables = append(dropTables, tableScope.node)
   185  		} else if !c.IfExists {
   186  			err := sql.ErrTableNotFound.New(tableName)
   187  			b.handleErr(err)
   188  		}
   189  	}
   190  
   191  	outScope.node = plan.NewDropTable(dropTables, c.IfExists)
   192  	return
   193  }
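
        // Editor's sketch (not part of the original source): given
        //
        //	DROP TABLE IF EXISTS t1, missing_tbl;
        //
        // the builder resolves t1, emits a Note-level warning (mysql.ERBadTable) for missing_tbl and
        // skips it, and passes only the resolved tables to plan.NewDropTable. Without IF EXISTS, an
        // unresolvable table raises sql.ErrTableNotFound instead.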
   194  
   195  func (b *Builder) buildTruncateTable(inScope *scope, c *ast.DDL) (outScope *scope) {
   196  	outScope = inScope.push()
   197  	dbName := c.Table.Qualifier.String()
   198  	tabName := c.Table.Name.String()
   199  	tableScope, ok := b.buildResolvedTable(inScope, dbName, tabName, nil)
   200  	if !ok {
   201  		b.handleErr(sql.ErrTableNotFound.New(tabName))
   202  	}
   203  	outScope.node = plan.NewTruncate(
   204  		c.Table.Qualifier.String(),
   205  		tableScope.node,
   206  	)
   207  	return
   208  }
   209  
   210  func (b *Builder) buildCreateTable(inScope *scope, c *ast.DDL) (outScope *scope) {
   211  	outScope = inScope.push()
   212  	if c.OptLike != nil {
   213  		return b.buildCreateTableLike(inScope, c)
   214  	}
   215  
   216  	qualifier := c.Table.Qualifier.String()
   217  	if qualifier == "" {
   218  		qualifier = b.ctx.GetCurrentDatabase()
   219  	}
   220  	database := b.resolveDb(qualifier)
   221  
   222  	// If no table spec is given but a SELECT statement is, return the CREATE TABLE ... SELECT node here;
   223  	// if the table spec is non-nil, it is parsed below.
   224  	if c.TableSpec == nil && c.OptSelect != nil {
   225  		tableSpec := &plan.TableSpec{}
   226  
   227  		selectScope := b.buildSelectStmt(inScope, c.OptSelect.Select)
   228  
   229  		outScope.node = plan.NewCreateTableSelect(database, c.Table.Name.String(), selectScope.node, tableSpec, plan.IfNotExistsOption(c.IfNotExists), plan.TempTableOption(c.Temporary))
   230  		return outScope
   231  	}
   232  
   233  	idxDefs := b.buildIndexDefs(inScope, c.TableSpec)
   234  
   235  	schema, collation, comment := b.tableSpecToSchema(inScope, outScope, database, strings.ToLower(c.Table.Name.String()), c.TableSpec, false)
   236  	fkDefs, chDefs := b.buildConstraintsDefs(outScope, c.Table, c.TableSpec)
   237  
   238  	schema.Schema = assignColumnIndexesInSchema(schema.Schema)
   239  	chDefs = assignColumnIndexesInCheckDefs(chDefs, schema.Schema)
   240  
   241  	if privDb, ok := database.(mysql_db.PrivilegedDatabase); ok {
   242  		if sv, ok := privDb.Unwrap().(sql.SchemaValidator); ok {
   243  			if err := sv.ValidateSchema(schema.PhysicalSchema()); err != nil {
   244  				b.handleErr(err)
   245  			}
   246  		}
   247  	} else {
   248  		if sv, ok := database.(sql.SchemaValidator); ok {
   249  			if err := sv.ValidateSchema(schema.PhysicalSchema()); err != nil {
   250  				b.handleErr(err)
   251  			}
   252  		}
   253  	}
   254  
   255  	tableSpec := &plan.TableSpec{
   256  		Schema:    schema,
   257  		IdxDefs:   idxDefs,
   258  		FkDefs:    fkDefs,
   259  		ChDefs:    chDefs,
   260  		Collation: collation,
   261  		Comment:   comment,
   262  	}
   263  
   264  	if c.OptSelect != nil {
   265  		selectScope := b.buildSelectStmt(inScope, c.OptSelect.Select)
   266  		outScope.node = plan.NewCreateTableSelect(database, c.Table.Name.String(), selectScope.node, tableSpec, plan.IfNotExistsOption(c.IfNotExists), plan.TempTableOption(c.Temporary))
   267  	} else {
   268  		outScope.node = plan.NewCreateTable(
   269  			database, c.Table.Name.String(), plan.IfNotExistsOption(c.IfNotExists), plan.TempTableOption(c.Temporary), tableSpec)
   270  	}
   271  
   272  	return
   273  }
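
        // Editor's sketch (not part of the original source): a plain
        //
        //	CREATE TABLE t (a int primary key, b varchar(10) unique);
        //
        // builds a full plan.TableSpec: the column definitions become the TableSpec's schema, the
        // inline UNIQUE produces an IdxDefs entry (see buildIndexDefs), and the result is a
        // plan.NewCreateTable node. A CREATE TABLE ... SELECT instead wraps the built SELECT in
        // plan.NewCreateTableSelect.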
   274  
   275  func assignColumnIndexesInCheckDefs(defs []*sql.CheckConstraint, schema sql.Schema) []*sql.CheckConstraint {
   276  	newDefs := make([]*sql.CheckConstraint, len(defs))
   277  	for i, def := range defs {
   278  		newDefs[i] = def
   279  		newDefs[i].Expr = assignColumnIndexes(def.Expr, schema).(sql.Expression)
   280  	}
   281  	return newDefs
   282  }
   283  
   284  func assignColumnIndexesInSchema(schema sql.Schema) sql.Schema {
   285  	newSch := make(sql.Schema, len(schema))
   286  	for i, col := range schema {
   287  		newSch[i] = col
   288  		if col.Default != nil {
   289  			newSch[i].Default = assignColumnIndexes(col.Default, schema).(*sql.ColumnDefaultValue)
   290  		}
   291  		if col.Generated != nil {
   292  			newSch[i].Generated = assignColumnIndexes(col.Generated, schema).(*sql.ColumnDefaultValue)
   293  		}
   294  	}
   295  	return newSch
   296  }
   297  
   298  func (b *Builder) buildCreateTableLike(inScope *scope, ct *ast.DDL) *scope {
   299  	tableName := ct.OptLike.LikeTable.Name.String()
   300  	likeDbName := ct.OptLike.LikeTable.Qualifier.String()
   301  	if likeDbName == "" {
   302  		likeDbName = b.ctx.GetCurrentDatabase()
   303  	}
   304  	outScope, ok := b.buildTablescan(inScope, likeDbName, tableName, nil)
   305  	if !ok {
   306  		b.handleErr(sql.ErrTableNotFound.New(tableName))
   307  	}
   308  	likeTable, ok := outScope.node.(*plan.ResolvedTable)
   309  	if !ok {
   310  		err := fmt.Errorf("expected resolved table: %s", tableName)
   311  		b.handleErr(err)
   312  	}
   313  
   314  	newTableName := strings.ToLower(ct.Table.Name.String())
   315  	outScope.setTableAlias(newTableName)
   316  
   317  	var idxDefs []*plan.IndexDefinition
   318  	if indexableTable, ok := likeTable.Table.(sql.IndexAddressableTable); ok {
   319  		indexes, err := indexableTable.GetIndexes(b.ctx)
   320  		if err != nil {
   321  			b.handleErr(err)
   322  		}
   323  		for _, index := range indexes {
   324  			if index.IsGenerated() {
   325  				continue
   326  			}
   327  			constraint := sql.IndexConstraint_None
   328  			if index.IsUnique() {
   329  				if index.ID() == "PRIMARY" {
   330  					constraint = sql.IndexConstraint_Primary
   331  				} else {
   332  					constraint = sql.IndexConstraint_Unique
   333  				}
   334  			}
   335  
   336  			columns := make([]sql.IndexColumn, len(index.Expressions()))
   337  			for i, col := range index.Expressions() {
   338  				//TODO: find a better way to get only the column name if the table is present
   339  				col = strings.TrimPrefix(col, indexableTable.Name()+".")
   340  				columns[i] = sql.IndexColumn{
   341  					Name:   col,
   342  					Length: 0,
   343  				}
   344  			}
   345  			idxDefs = append(idxDefs, &plan.IndexDefinition{
   346  				IndexName:  index.ID(),
   347  				Using:      sql.IndexUsing_Default,
   348  				Constraint: constraint,
   349  				Columns:    columns,
   350  				Comment:    index.Comment(),
   351  			})
   352  		}
   353  	}
   354  	origSch := likeTable.Schema()
   355  	newSch := make(sql.Schema, len(origSch))
   356  	for i, col := range origSch {
   357  		tempCol := *col
   358  		tempCol.Source = newTableName
   359  		newSch[i] = &tempCol
   360  	}
   361  
   362  	var pkOrdinals []int
   363  	if pkTable, ok := likeTable.Table.(sql.PrimaryKeyTable); ok {
   364  		pkOrdinals = pkTable.PrimaryKeySchema().PkOrdinals
   365  	}
   366  
   367  	var checkDefs []*sql.CheckConstraint
   368  	if checksTable, ok := likeTable.Table.(sql.CheckTable); ok {
   369  		checks, err := checksTable.GetChecks(b.ctx)
   370  		if err != nil {
   371  			b.handleErr(err)
   372  		}
   373  
   374  		for _, check := range checks {
   375  			checkConstraint := b.buildCheckConstraint(outScope, &check)
   379  
   380  			// Prevent a name collision between the old and new checks.
   381  			// The new check will be assigned a name during building.
   382  			checkConstraint.Name = ""
   383  			checkDefs = append(checkDefs, checkConstraint)
   384  		}
   385  	}
   386  
   387  	pkSchema := sql.NewPrimaryKeySchema(newSch, pkOrdinals...)
   388  	pkSchema.Schema = b.resolveSchemaDefaults(outScope, pkSchema.Schema)
   389  
   390  	tableSpec := &plan.TableSpec{
   391  		Schema:    pkSchema,
   392  		IdxDefs:   idxDefs,
   393  		ChDefs:    checkDefs,
   394  		Collation: likeTable.Collation(),
   395  		Comment:   likeTable.Comment(),
   396  	}
   397  
   398  	qualifier := ct.Table.Qualifier.String()
   399  	if qualifier == "" {
   400  		qualifier = b.ctx.GetCurrentDatabase()
   401  	}
   402  	database := b.resolveDb(qualifier)
   403  
   404  	outScope.node = plan.NewCreateTable(database, newTableName, plan.IfNotExistsOption(ct.IfNotExists), plan.TempTableOption(ct.Temporary), tableSpec)
   405  	return outScope
   406  }
   407  
   408  func (b *Builder) buildRenameTable(inScope *scope, ddl *ast.DDL) (outScope *scope) {
   409  	outScope = inScope
   410  	if len(ddl.FromTables) != len(ddl.ToTables) {
   411  		panic("Expected from tables and to tables of equal length")
   412  	}
   413  
   414  	var fromTables, toTables []string
   415  	for _, table := range ddl.FromTables {
   416  		fromTables = append(fromTables, table.Name.String())
   417  	}
   418  	for _, table := range ddl.ToTables {
   419  		toTables = append(toTables, table.Name.String())
   420  	}
   421  
   422  	outScope.node = plan.NewRenameTable(b.currentDb(), fromTables, toTables, b.multiDDL)
   423  	return
   424  }
   425  
   426  func (b *Builder) isUniqueColumn(tableSpec *ast.TableSpec, columnName string) bool {
   427  	for _, column := range tableSpec.Columns {
   428  		if column.Name.String() == columnName {
   429  			return column.Type.KeyOpt == colKeyUnique ||
   430  				column.Type.KeyOpt == colKeyUniqueKey
   431  		}
   432  	}
   433  	err := fmt.Errorf("unknown column name %s", columnName)
   434  	b.handleErr(err)
   435  	return false
   436  
   437  }
   438  
   439  func (b *Builder) buildAlterTableClause(inScope *scope, ddl *ast.DDL) []*scope {
   440  	outScopes := make([]*scope, 0, 1)
   441  
   442  	// RENAME a to b, c to d ..
   443  	if ddl.Action == ast.RenameStr {
   444  		outScopes = append(outScopes, b.buildRenameTable(inScope, ddl))
   445  	} else {
   446  		dbName := ddl.Table.Qualifier.String()
   447  		tableName := ddl.Table.Name.String()
   448  		var ok bool
   449  		tableScope, ok := b.buildResolvedTable(inScope, dbName, tableName, nil)
   450  		if !ok {
   451  			b.handleErr(sql.ErrTableNotFound.New(tableName))
   452  		}
   453  		rt, ok := tableScope.node.(*plan.ResolvedTable)
   454  		if !ok {
   455  			err := fmt.Errorf("expected resolved table: %s", tableName)
   456  			b.handleErr(err)
   457  		}
   458  
   459  		if ddl.ColumnAction != "" {
   460  			columnActionOutscope := b.buildAlterTableColumnAction(tableScope, ddl, rt)
   461  			outScopes = append(outScopes, columnActionOutscope)
   462  
   463  			if ddl.TableSpec != nil {
   464  				if len(ddl.TableSpec.Columns) != 1 {
   465  					err := sql.ErrUnsupportedFeature.New("unexpected number of columns in a single alter column clause")
   466  					b.handleErr(err)
   467  				}
   468  
   469  				column := ddl.TableSpec.Columns[0]
   470  				isUnique := b.isUniqueColumn(ddl.TableSpec, column.Name.String())
   471  				if isUnique {
   472  					createIndex := plan.NewAlterCreateIndex(
   473  						rt.Database(),
   474  						rt,
   475  						column.Name.String(),
   476  						sql.IndexUsing_BTree,
   477  						sql.IndexConstraint_Unique,
   478  						[]sql.IndexColumn{{Name: column.Name.String()}},
   479  						"",
   480  					)
   481  
   482  					createIndexScope := inScope.push()
   483  					createIndexScope.node = createIndex
   484  					outScopes = append(outScopes, createIndexScope)
   485  				}
   486  			}
   487  		}
   488  
   489  		if ddl.ConstraintAction != "" {
   490  			if len(ddl.TableSpec.Constraints) != 1 {
   491  				b.handleErr(sql.ErrUnsupportedFeature.New("unexpected number of constraints in a single alter constraint clause"))
   492  			}
   493  			outScopes = append(outScopes, b.buildAlterConstraint(tableScope, ddl, rt))
   494  		}
   495  
   496  		if ddl.IndexSpec != nil {
   497  			outScopes = append(outScopes, b.buildAlterIndex(tableScope, ddl, rt))
   498  		}
   499  
   500  		if ddl.AutoIncSpec != nil {
   501  			outScopes = append(outScopes, b.buildAlterAutoIncrement(tableScope, ddl, rt))
   502  		}
   503  
   504  		if ddl.DefaultSpec != nil {
   505  			outScopes = append(outScopes, b.buildAlterDefault(tableScope, ddl, rt))
   506  		}
   507  
   508  		if ddl.AlterCollationSpec != nil {
   509  			outScopes = append(outScopes, b.buildAlterCollationSpec(tableScope, ddl, rt))
   510  		}
   511  
   512  		for _, s := range outScopes {
   513  			if ts, ok := s.node.(sql.SchemaTarget); ok {
   514  				s.node = b.modifySchemaTarget(s, ts, rt)
   515  			}
   516  		}
   517  		pkt, _ := rt.Table.(sql.PrimaryKeyTable)
   518  		if pkt != nil {
   519  			for _, s := range outScopes {
   520  				if ts, ok := s.node.(sql.PrimaryKeySchemaTarget); ok {
   521  					s.node = b.modifySchemaTarget(inScope, ts, rt)
   522  					ts.WithPrimaryKeySchema(pkt.PrimaryKeySchema())
   523  				}
   524  			}
   525  		}
   526  	}
   527  	return outScopes
   528  }
   529  
   530  func (b *Builder) buildAlterTableColumnAction(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) {
   531  	outScope = inScope
   532  	switch strings.ToLower(ddl.ColumnAction) {
   533  	case ast.AddStr:
   534  		sch, _, _ := b.tableSpecToSchema(inScope, outScope, table.Database(), ddl.Table.Name.String(), ddl.TableSpec, true)
   535  		outScope.node = plan.NewAddColumnResolved(table, *sch.Schema[0], columnOrderToColumnOrder(ddl.ColumnOrder))
   536  	case ast.DropStr:
   537  		drop := plan.NewDropColumnResolved(table, ddl.Column.String())
   538  		checks := b.loadChecksFromTable(outScope, table.Table)
   539  		outScope.node = drop.WithChecks(checks)
   540  	case ast.RenameStr:
   541  		rename := plan.NewRenameColumnResolved(table, ddl.Column.String(), ddl.ToColumn.String())
   542  		checks := b.loadChecksFromTable(outScope, table.Table)
   543  		outScope.node = rename.WithChecks(checks)
   544  	case ast.ModifyStr, ast.ChangeStr:
   545  		// MODIFY/CHANGE adds a new column, possibly with the same name as the old one;
   546  		// push a new scope so the new column resolves before the old column.
   547  		outScope = inScope.push()
   548  		sch, _, _ := b.tableSpecToSchema(inScope, outScope, table.Database(), ddl.Table.Name.String(), ddl.TableSpec, true)
   549  		modifyCol := plan.NewModifyColumnResolved(table, ddl.Column.String(), *sch.Schema[0], columnOrderToColumnOrder(ddl.ColumnOrder))
   550  		outScope.node = modifyCol
   551  	default:
   552  		err := sql.ErrUnsupportedFeature.New(ast.String(ddl))
   553  		b.handleErr(err)
   554  	}
   555  
   556  	return outScope
   557  }
   558  
   559  func (b *Builder) buildAlterConstraint(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) {
   560  	outScope = inScope
   561  	parsedConstraint := b.convertConstraintDefinition(inScope, ddl.TableSpec.Constraints[0])
   562  	switch strings.ToLower(ddl.ConstraintAction) {
   563  	case ast.AddStr:
   564  		switch c := parsedConstraint.(type) {
   565  		case *sql.ForeignKeyConstraint:
   566  			c.Database = table.SqlDatabase.Name()
   567  			c.Table = table.Name()
   568  			alterFk := plan.NewAlterAddForeignKey(c)
   569  			alterFk.DbProvider = b.cat
   570  			outScope.node = alterFk
   571  		case *sql.CheckConstraint:
   572  			outScope.node = plan.NewAlterAddCheck(table, c)
   573  		default:
   574  			err := sql.ErrUnsupportedFeature.New(ast.String(ddl))
   575  			b.handleErr(err)
   576  		}
   577  	case ast.DropStr:
   578  		switch c := parsedConstraint.(type) {
   579  		case *sql.ForeignKeyConstraint:
   580  			database := table.SqlDatabase.Name()
   581  			dropFk := plan.NewAlterDropForeignKey(database, table.Name(), c.Name)
   582  			dropFk.DbProvider = b.cat
   583  			outScope.node = dropFk
   584  		case *sql.CheckConstraint:
   585  			outScope.node = plan.NewAlterDropCheck(table, c.Name)
   586  		case namedConstraint:
   587  			outScope.node = &plan.DropConstraint{
   588  				UnaryNode: plan.UnaryNode{Child: table},
   589  				Name:      c.name,
   590  			}
   591  		default:
   592  			err := sql.ErrUnsupportedFeature.New(ast.String(ddl))
   593  			b.handleErr(err)
   594  		}
   595  	}
   596  	return
   597  }
   598  
   599  func (b *Builder) buildConstraintsDefs(inScope *scope, tname ast.TableName, spec *ast.TableSpec) (fks []*sql.ForeignKeyConstraint, checks []*sql.CheckConstraint) {
   600  	for _, unknownConstraint := range spec.Constraints {
   601  		parsedConstraint := b.convertConstraintDefinition(inScope, unknownConstraint)
   602  		switch constraint := parsedConstraint.(type) {
   603  		case *sql.ForeignKeyConstraint:
   604  			constraint.Database = tname.Qualifier.String()
   605  			constraint.Table = tname.Name.String()
   606  			if constraint.Database == "" {
   607  				constraint.Database = b.ctx.GetCurrentDatabase()
   608  			}
   609  			fks = append(fks, constraint)
   610  		case *sql.CheckConstraint:
   611  			checks = append(checks, constraint)
   612  		default:
   613  			err := sql.ErrUnknownConstraintDefinition.New(unknownConstraint.Name, unknownConstraint)
   614  			b.handleErr(err)
   615  		}
   616  	}
   617  	return
   618  }
   619  
   620  func columnOrderToColumnOrder(order *ast.ColumnOrder) *sql.ColumnOrder {
   621  	if order == nil {
   622  		return nil
   623  	}
   624  	if order.First {
   625  		return &sql.ColumnOrder{First: true}
   626  	} else {
   627  		return &sql.ColumnOrder{AfterColumn: order.AfterColumn.String()}
   628  	}
   629  }
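
        // Editor's note (not part of the original source): this maps the optional FIRST / AFTER
        // clause of ADD/MODIFY COLUMN, e.g. `ADD COLUMN b int AFTER a` yields
        // &sql.ColumnOrder{AfterColumn: "a"}, while an absent clause yields nil.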
   630  
   631  func (b *Builder) buildIndexDefs(inScope *scope, spec *ast.TableSpec) (idxDefs []*plan.IndexDefinition) {
   632  	for _, idxDef := range spec.Indexes {
   633  		constraint := sql.IndexConstraint_None
   634  		if idxDef.Info.Primary {
   635  			constraint = sql.IndexConstraint_Primary
   636  		} else if idxDef.Info.Unique {
   637  			constraint = sql.IndexConstraint_Unique
   638  		} else if idxDef.Info.Spatial {
   639  			constraint = sql.IndexConstraint_Spatial
   640  		} else if idxDef.Info.Fulltext {
   641  			constraint = sql.IndexConstraint_Fulltext
   642  		}
   643  
   644  		columns := b.gatherIndexColumns(idxDef.Columns)
   645  
   646  		var comment string
   647  		for _, option := range idxDef.Options {
   648  			if strings.EqualFold(option.Name, ast.KeywordString(ast.COMMENT_KEYWORD)) {
   649  				comment = string(option.Value.Val)
   650  			}
   651  		}
   652  		idxDefs = append(idxDefs, &plan.IndexDefinition{
   653  			IndexName:  idxDef.Info.Name.String(),
   654  			Using:      sql.IndexUsing_Default, //TODO: add vitess support for USING
   655  			Constraint: constraint,
   656  			Columns:    columns,
   657  			Comment:    comment,
   658  		})
   659  	}
   660  
   661  	for _, colDef := range spec.Columns {
   662  		if colDef.Type.KeyOpt == colKeyFulltextKey {
   663  			idxDefs = append(idxDefs, &plan.IndexDefinition{
   664  				IndexName:  "",
   665  				Using:      sql.IndexUsing_Default,
   666  				Constraint: sql.IndexConstraint_Fulltext,
   667  				Comment:    "",
   668  				Columns: []sql.IndexColumn{{
   669  					Name:   colDef.Name.String(),
   670  					Length: 0,
   671  				}},
   672  			})
   673  		} else if colDef.Type.KeyOpt == colKeyUnique || colDef.Type.KeyOpt == colKeyUniqueKey {
   674  			idxDefs = append(idxDefs, &plan.IndexDefinition{
   675  				IndexName:  "",
   676  				Using:      sql.IndexUsing_Default,
   677  				Constraint: sql.IndexConstraint_Unique,
   678  				Comment:    "",
   679  				Columns: []sql.IndexColumn{{
   680  					Name:   colDef.Name.String(),
   681  					Length: 0,
   682  				}},
   683  			})
   684  		}
   685  	}
   686  	return
   687  }
   688  
   689  type namedConstraint struct {
   690  	name string
   691  }
   692  
   693  func (b *Builder) convertConstraintDefinition(inScope *scope, cd *ast.ConstraintDefinition) interface{} {
   694  	if fkConstraint, ok := cd.Details.(*ast.ForeignKeyDefinition); ok {
   695  		columns := make([]string, len(fkConstraint.Source))
   696  		for i, col := range fkConstraint.Source {
   697  			columns[i] = col.String()
   698  		}
   699  		refColumns := make([]string, len(fkConstraint.ReferencedColumns))
   700  		for i, col := range fkConstraint.ReferencedColumns {
   701  			refColumns[i] = col.String()
   702  		}
   703  		refDatabase := fkConstraint.ReferencedTable.Qualifier.String()
   704  		if refDatabase == "" {
   705  			refDatabase = b.ctx.GetCurrentDatabase()
   706  		}
   707  		// The database and table are set in the calling function
   708  		return &sql.ForeignKeyConstraint{
   709  			Name:           cd.Name,
   710  			Columns:        columns,
   711  			ParentDatabase: refDatabase,
   712  			ParentTable:    fkConstraint.ReferencedTable.Name.String(),
   713  			ParentColumns:  refColumns,
   714  			OnUpdate:       b.buildReferentialAction(fkConstraint.OnUpdate),
   715  			OnDelete:       b.buildReferentialAction(fkConstraint.OnDelete),
   716  			IsResolved:     false,
   717  		}
   718  	} else if chConstraint, ok := cd.Details.(*ast.CheckConstraintDefinition); ok {
   719  		var c sql.Expression
   720  		if chConstraint.Expr != nil {
   721  			c = b.buildScalar(inScope, chConstraint.Expr)
   722  		}
   723  
   724  		return &sql.CheckConstraint{
   725  			Name:     cd.Name,
   726  			Expr:     c,
   727  			Enforced: chConstraint.Enforced,
   728  		}
   729  	} else if len(cd.Name) > 0 && cd.Details == nil {
   730  		return namedConstraint{cd.Name}
   731  	}
   732  	err := sql.ErrUnknownConstraintDefinition.New(cd.Name, cd)
   733  	b.handleErr(err)
   734  	return nil
   735  }
   736  
   737  func (b *Builder) buildReferentialAction(action ast.ReferenceAction) sql.ForeignKeyReferentialAction {
   738  	switch action {
   739  	case ast.Restrict:
   740  		return sql.ForeignKeyReferentialAction_Restrict
   741  	case ast.Cascade:
   742  		return sql.ForeignKeyReferentialAction_Cascade
   743  	case ast.NoAction:
   744  		return sql.ForeignKeyReferentialAction_NoAction
   745  	case ast.SetNull:
   746  		return sql.ForeignKeyReferentialAction_SetNull
   747  	case ast.SetDefault:
   748  		return sql.ForeignKeyReferentialAction_SetDefault
   749  	default:
   750  		return sql.ForeignKeyReferentialAction_DefaultAction
   751  	}
   752  }
   753  
   754  func (b *Builder) buildAlterIndex(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) {
   755  	outScope = inScope
   756  	switch strings.ToLower(ddl.IndexSpec.Action) {
   757  	case ast.CreateStr:
   758  		var using sql.IndexUsing
   759  		switch ddl.IndexSpec.Using.Lowered() {
   760  		case "", "btree":
   761  			using = sql.IndexUsing_BTree
   762  		case "hash":
   763  			using = sql.IndexUsing_Hash
   764  		default:
   765  			return b.buildExternalCreateIndex(inScope, ddl)
   766  		}
   767  
   768  		var constraint sql.IndexConstraint
   769  		switch ddl.IndexSpec.Type {
   770  		case ast.UniqueStr:
   771  			constraint = sql.IndexConstraint_Unique
   772  		case ast.FulltextStr:
   773  			constraint = sql.IndexConstraint_Fulltext
   774  		case ast.SpatialStr:
   775  			constraint = sql.IndexConstraint_Spatial
   776  		case ast.PrimaryStr:
   777  			constraint = sql.IndexConstraint_Primary
   778  		default:
   779  			constraint = sql.IndexConstraint_None
   780  		}
   781  
   782  		columns := b.gatherIndexColumns(ddl.IndexSpec.Columns)
   783  
   784  		var comment string
   785  		for _, option := range ddl.IndexSpec.Options {
   786  			if strings.EqualFold(option.Name, ast.KeywordString(ast.COMMENT_KEYWORD)) {
   787  				comment = string(option.Value.Val)
   788  			}
   789  		}
   790  
   791  		if constraint == sql.IndexConstraint_Primary {
   792  			outScope.node = plan.NewAlterCreatePk(table.SqlDatabase, table, columns)
   793  			return
   794  		}
   795  
   796  		indexName := ddl.IndexSpec.ToName.String()
   797  		if strings.ToLower(indexName) == ast.PrimaryStr {
   798  			err := sql.ErrInvalidIndexName.New(indexName)
   799  			b.handleErr(err)
   800  		}
   801  
   802  		createIndex := plan.NewAlterCreateIndex(table.SqlDatabase, table, ddl.IndexSpec.ToName.String(), using, constraint, columns, comment)
   803  		outScope.node = b.modifySchemaTarget(inScope, createIndex, table)
   804  		return
   805  	case ast.DropStr:
   806  		if ddl.IndexSpec.Type == ast.PrimaryStr {
   807  			outScope.node = plan.NewAlterDropPk(table.SqlDatabase, table)
   808  			return
   809  		}
   810  		outScope.node = plan.NewAlterDropIndex(table.Database(), table, ddl.IndexSpec.ToName.String())
   811  		return
   812  	case ast.RenameStr:
   813  		outScope.node = plan.NewAlterRenameIndex(table.Database(), table, ddl.IndexSpec.FromName.String(), ddl.IndexSpec.ToName.String())
   814  		return
   815  	case "disable":
   816  		outScope.node = plan.NewAlterDisableEnableKeys(table.SqlDatabase, table, true)
   817  		return
   818  	case "enable":
   819  		outScope.node = plan.NewAlterDisableEnableKeys(table.SqlDatabase, table, false)
   820  		return
   821  	default:
   822  		err := sql.ErrUnsupportedFeature.New(ast.String(ddl))
   823  		b.handleErr(err)
   824  	}
   825  	return
   826  }
   827  
   828  func (b *Builder) gatherIndexColumns(cols []*ast.IndexColumn) []sql.IndexColumn {
   829  	out := make([]sql.IndexColumn, len(cols))
   830  	for i, col := range cols {
   831  		var length int64
   832  		var err error
   833  		if col.Length != nil && col.Length.Type == ast.IntVal {
   834  			length, err = strconv.ParseInt(string(col.Length.Val), 10, 64)
   835  			if err != nil {
   836  				b.handleErr(err)
   837  			}
   838  			if length < 1 {
   839  				err := sql.ErrKeyZero.New(col.Column)
   840  				b.handleErr(err)
   841  			}
   842  		}
   843  		out[i] = sql.IndexColumn{
   844  			Name:   col.Column.String(),
   845  			Length: length,
   846  		}
   847  	}
   848  	return out
   849  }
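
        // Editor's sketch (not part of the original source): for an index definition such as
        //
        //	KEY idx (a(10), b)
        //
        // gatherIndexColumns returns []sql.IndexColumn{{Name: "a", Length: 10}, {Name: "b", Length: 0}};
        // a prefix length below 1 is rejected with sql.ErrKeyZero.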
   850  
   851  func (b *Builder) buildAlterAutoIncrement(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) {
   852  	outScope = inScope
   853  	val, ok := ddl.AutoIncSpec.Value.(*ast.SQLVal)
   854  	if !ok {
   855  		err := sql.ErrInvalidSQLValType.New(ddl.AutoIncSpec.Value)
   856  		b.handleErr(err)
   857  	}
   858  
   859  	var autoVal uint64
   860  	if val.Type == ast.IntVal {
   861  		i, err := strconv.ParseUint(string(val.Val), 10, 64)
   862  		if err != nil {
   863  			b.handleErr(err)
   864  		}
   865  		autoVal = i
   866  	} else if val.Type == ast.FloatVal {
   867  		f, err := strconv.ParseFloat(string(val.Val), 64)
   868  		if err != nil {
   869  			b.handleErr(err)
   870  		}
   871  		autoVal = uint64(f)
   872  	} else {
   873  		err := sql.ErrInvalidSQLValType.New(ddl.AutoIncSpec.Value)
   874  		b.handleErr(err)
   875  	}
   876  
   877  	outScope.node = plan.NewAlterAutoIncrement(table.Database(), table, autoVal)
   878  	return
   879  }
   880  
   881  func (b *Builder) buildAlterDefault(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) {
   882  	outScope = inScope
   883  	switch strings.ToLower(ddl.DefaultSpec.Action) {
   884  	case ast.SetStr:
   885  		for _, c := range table.Schema() {
   886  			if strings.EqualFold(c.Name, ddl.DefaultSpec.Column.String()) {
   887  				defaultExpr := b.convertDefaultExpression(inScope, ddl.DefaultSpec.Value, c.Type, c.Nullable)
   888  				defSet := plan.NewAlterDefaultSet(table.Database(), table, ddl.DefaultSpec.Column.String(), defaultExpr)
   889  				outScope.node = b.modifySchemaTarget(inScope, defSet, table)
   890  				return
   891  			}
   892  		}
   893  		err := sql.ErrTableColumnNotFound.New(table.Name(), ddl.DefaultSpec.Column.String())
   894  		b.handleErr(err)
   895  		return
   896  	case ast.DropStr:
   897  		outScope.node = plan.NewAlterDefaultDrop(table.Database(), table, ddl.DefaultSpec.Column.String())
   898  		return
   899  	default:
   900  		err := sql.ErrUnsupportedFeature.New(ast.String(ddl))
   901  		b.handleErr(err)
   902  	}
   903  	return
   904  }
   905  
   906  func (b *Builder) buildAlterCollationSpec(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) {
   907  	outScope = inScope
   908  	var charSetStr *string
   909  	var collationStr *string
   910  	if len(ddl.AlterCollationSpec.CharacterSet) > 0 {
   911  		charSetStr = &ddl.AlterCollationSpec.CharacterSet
   912  	}
   913  	if len(ddl.AlterCollationSpec.Collation) > 0 {
   914  		collationStr = &ddl.AlterCollationSpec.Collation
   915  	}
   916  	collation, err := sql.ParseCollation(charSetStr, collationStr, false)
   917  	if err != nil {
   918  		b.handleErr(err)
   919  	}
   920  	outScope.node = plan.NewAlterTableCollationResolved(table, collation)
   921  	return
   922  }
   923  
   924  func (b *Builder) buildDefaultExpression(inScope *scope, defaultExpr ast.Expr) *sql.ColumnDefaultValue {
   925  	if defaultExpr == nil {
   926  		return nil
   927  	}
   928  	parsedExpr := b.buildScalar(inScope, defaultExpr)
   929  
   930  	// Function expressions must be enclosed in parentheses (except for current_timestamp() and now())
   931  	_, isParenthesized := defaultExpr.(*ast.ParenExpr)
   932  	isLiteral := !isParenthesized
   933  
   934  	// A literal will never have children, thus we can also check for that.
   935  	if unaryExpr, is := defaultExpr.(*ast.UnaryExpr); is {
   936  		if _, lit := unaryExpr.Expr.(*ast.SQLVal); lit {
   937  			isLiteral = true
   938  		}
   939  	} else if !isParenthesized {
   940  		if f, ok := parsedExpr.(*expression.UnresolvedFunction); ok {
   941  			// Datetime and Timestamp columns allow now and current_timestamp to not be enclosed in parens,
   942  			// but they still need to be treated as function expressions
   943  			switch strings.ToLower(f.Name()) {
   944  			case "now", "current_timestamp", "localtime", "localtimestamp":
   945  				isLiteral = false
   946  			default:
   947  				err := sql.ErrSyntaxError.New("column default function expressions must be enclosed in parentheses")
   948  				b.handleErr(err)
   949  			}
   950  		}
   951  	}
   952  
   953  	return ExpressionToColumnDefaultValue(parsedExpr, isLiteral, isParenthesized)
   954  }
   955  
   956  // ExpressionToColumnDefaultValue takes in an Expression and returns the equivalent ColumnDefaultValue if the expression
   957  // is valid for a default value. If the expression represents a literal (and not an expression that returns a literal, so "5"
   958  // rather than "(5)"), then the parameter "isLiteral" should be true.
   959  func ExpressionToColumnDefaultValue(inputExpr sql.Expression, isLiteral, isParenthesized bool) *sql.ColumnDefaultValue {
   960  	return &sql.ColumnDefaultValue{
   961  		Expr:          inputExpr,
   962  		OutType:       nil,
   963  		Literal:       isLiteral,
   964  		ReturnNil:     true,
   965  		Parenthesized: isParenthesized,
   966  	}
   967  }
   968  
   969  func (b *Builder) buildExternalCreateIndex(inScope *scope, ddl *ast.DDL) (outScope *scope) {
   970  	config := make(map[string]string)
   971  	for _, option := range ddl.IndexSpec.Options {
   972  		if option.Using != "" {
   973  			config[option.Name] = option.Using
   974  		} else {
   975  			config[option.Name] = string(option.Value.Val)
   976  		}
   977  	}
   978  
   979  	dbName := strings.ToLower(ddl.Table.Qualifier.String())
   980  	tblName := strings.ToLower(ddl.Table.Name.String())
   981  	var ok bool
   982  	outScope, ok = b.buildTablescan(inScope, dbName, tblName, nil)
   983  	if !ok {
   984  		b.handleErr(sql.ErrTableNotFound.New(tblName))
   985  	}
   986  	table, ok := outScope.node.(*plan.ResolvedTable)
   987  	if !ok {
   988  		err := fmt.Errorf("expected resolved table: %s", tblName)
   989  		b.handleErr(err)
   990  	}
   991  
   992  	tableId := outScope.tables[tblName]
   993  	cols := make([]sql.Expression, len(ddl.IndexSpec.Columns))
   994  	for i, col := range ddl.IndexSpec.Columns {
   995  		colName := strings.ToLower(col.Column.String())
   996  		c, ok := inScope.resolveColumn(dbName, tblName, colName, true, false)
   997  		if !ok {
   998  			b.handleErr(sql.ErrColumnNotFound.New(colName))
   999  		}
  1000  		cols[i] = expression.NewGetFieldWithTable(int(c.id), int(tableId), c.typ, c.db, c.table, c.col, c.nullable)
  1001  	}
  1002  
  1003  	createIndex := plan.NewCreateIndex(
  1004  		ddl.IndexSpec.ToName.String(),
  1005  		table,
  1006  		cols,
  1007  		ddl.IndexSpec.Using.Lowered(),
  1008  		config,
  1009  	)
  1010  	createIndex.Catalog = b.cat
  1011  	outScope.node = createIndex
  1012  	return
  1013  }
  1014  
  1015  // validateOnUpdateExprs ensures that the time function used in a column's ON UPDATE clause is valid
  1016  func validateOnUpdateExprs(col *sql.Column) error {
  1017  	if col.OnUpdate == nil {
  1018  		return nil
  1019  	}
  1020  	if !(types.IsDatetimeType(col.Type) || types.IsTimestampType(col.Type)) {
  1021  		return sql.ErrInvalidOnUpdate.New(col.Name)
  1022  	}
  1023  	now, ok := col.OnUpdate.Expr.(*function.Now)
  1024  	if !ok {
  1025  		return nil
  1026  	}
  1027  	children := now.Children()
  1028  	if len(children) == 0 {
  1029  		return nil
  1030  	}
  1031  	lit, isLit := children[0].(*expression.Literal)
  1032  	if !isLit {
  1033  		return nil
  1034  	}
  1035  	val, err := lit.Eval(nil, nil)
  1036  	if err != nil {
  1037  		return err
  1038  	}
  1039  	prec, ok := types.CoalesceInt(val)
  1040  	if !ok {
  1041  		return sql.ErrInvalidOnUpdate.New(col.Name)
  1042  	}
  1043  	if prec != 0 {
  1044  		return sql.ErrInvalidOnUpdate.New(col.Name)
  1045  	}
  1046  	return nil
  1047  }
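
        // Editor's sketch (not part of the original source): a column definition like
        //
        //	updated_at timestamp ON UPDATE CURRENT_TIMESTAMP
        //
        // passes validation, while `i int ON UPDATE CURRENT_TIMESTAMP` fails because the column type
        // is neither DATETIME nor TIMESTAMP. As written here, a literal nonzero precision argument
        // (e.g. CURRENT_TIMESTAMP(6)) is also rejected with sql.ErrInvalidOnUpdate.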
  1048  
  1049  // tableSpecToSchema creates a sql.PrimaryKeySchema from a parsed TableSpec and returns it along with the table's collation ID and comment.
  1050  func (b *Builder) tableSpecToSchema(inScope, outScope *scope, db sql.Database, tableName string, tableSpec *ast.TableSpec, forceInvalidCollation bool) (sql.PrimaryKeySchema, sql.CollationID, string) {
  1051  	// todo: somewhere downstream, an ALTER ... MODIFY column's type collation is updated
  1052  	// to match the underlying. That only happens if the type stays unspecified.
  1053  	tableCollation := sql.Collation_Unspecified
  1054  	tableComment := ""
  1055  	if !forceInvalidCollation {
  1056  		tableCollation = sql.Collation_Default
  1057  		if cdb, _ := db.(sql.CollatedDatabase); cdb != nil {
  1058  			tableCollation = cdb.GetCollation(b.ctx)
  1059  		}
  1060  		if len(tableSpec.Options) > 0 {
  1061  			charsetSubmatches := tableCharsetOptionRegex.FindStringSubmatch(tableSpec.Options)
  1062  			collationSubmatches := tableCollationOptionRegex.FindStringSubmatch(tableSpec.Options)
  1063  			commentSubmatches := tableCommentOptionRegex.FindStringSubmatch(tableSpec.Options)
  1064  			if len(charsetSubmatches) == 5 && len(collationSubmatches) == 5 {
  1065  				var err error
  1066  				tableCollation, err = sql.ParseCollation(&charsetSubmatches[4], &collationSubmatches[4], false)
  1067  				if err != nil {
  1068  					return sql.PrimaryKeySchema{}, sql.Collation_Unspecified, ""
  1069  				}
  1070  			} else if len(charsetSubmatches) == 5 {
  1071  				charset, err := sql.ParseCharacterSet(charsetSubmatches[4])
  1072  				if err != nil {
  1073  					return sql.PrimaryKeySchema{}, sql.Collation_Unspecified, ""
  1074  				}
  1075  				tableCollation = charset.DefaultCollation()
  1076  			} else if len(collationSubmatches) == 5 {
  1077  				var err error
  1078  				tableCollation, err = sql.ParseCollation(nil, &collationSubmatches[4], false)
  1079  				if err != nil {
  1080  					return sql.PrimaryKeySchema{}, sql.Collation_Unspecified, ""
  1081  				}
  1082  			}
  1083  			if len(commentSubmatches) == 5 {
  1084  				tableComment = commentSubmatches[4]
  1085  			}
  1086  		}
  1087  	}
  1088  
  1089  	tabId := outScope.addTable(tableName)
  1090  
  1091  	defaults := make([]ast.Expr, len(tableSpec.Columns))
  1092  	generated := make([]ast.Expr, len(tableSpec.Columns))
  1093  	updates := make([]ast.Expr, len(tableSpec.Columns))
  1094  	var schema sql.Schema
  1095  	for i, cd := range tableSpec.Columns {
  1096  		if cd.Type.ResolvedType == nil {
  1097  			sqlType := cd.Type.SQLType()
  1098  			// Use the table's collation if no character set or collation was specified for the column
  1099  			if len(cd.Type.Charset) == 0 && len(cd.Type.Collate) == 0 {
  1100  				if tableCollation != sql.Collation_Unspecified && !types.IsBinary(sqlType) {
  1101  					cd.Type.Collate = tableCollation.Name()
  1102  				}
  1103  			}
  1104  		}
  1105  		defaults[i] = cd.Type.Default
  1106  		generated[i] = cd.Type.GeneratedExpr
  1107  		updates[i] = cd.Type.OnUpdate
  1108  
  1109  		column := b.columnDefinitionToColumn(inScope, cd, tableSpec.Indexes)
  1110  		column.DatabaseSource = db.Name()
  1111  
  1112  		if column.PrimaryKey && bool(cd.Type.Null) {
  1113  			b.handleErr(ErrPrimaryKeyOnNullField.New())
  1114  		}
  1115  
  1116  		schema = append(schema, column)
  1117  		outScope.newColumn(scopeColumn{
  1118  			tableId:  tabId,
  1119  			table:    tableName,
  1120  			db:       db.Name(),
  1121  			col:      strings.ToLower(column.Name),
  1122  			typ:      column.Type,
  1123  			nullable: column.Nullable,
  1124  		})
  1125  	}
  1126  
  1127  	for i, def := range defaults {
  1128  		schema[i].Default = b.convertDefaultExpression(outScope, def, schema[i].Type, schema[i].Nullable)
  1129  		if def != nil && generated[i] != nil {
  1130  			b.handleErr(sql.ErrGeneratedColumnWithDefault.New())
  1131  			return sql.PrimaryKeySchema{}, sql.Collation_Unspecified, ""
  1132  		}
  1133  	}
  1134  
  1135  	for i, gen := range generated {
  1136  		if gen != nil {
  1137  			virtual := !bool(tableSpec.Columns[i].Type.Stored)
  1138  			schema[i].Generated = b.convertDefaultExpression(outScope, gen, schema[i].Type, schema[i].Nullable)
  1139  			// generated expressions are always parenthesized, but we don't record this in the parser
  1140  			schema[i].Generated.Parenthesized = true
  1141  			schema[i].Generated.Literal = false
  1142  			schema[i].Virtual = virtual
  1143  		}
  1144  	}
  1145  
  1146  	for i, onUpdateExpr := range updates {
  1147  		schema[i].OnUpdate = b.convertDefaultExpression(outScope, onUpdateExpr, schema[i].Type, schema[i].Nullable)
  1148  		err := validateOnUpdateExprs(schema[i])
  1149  		if err != nil {
  1150  			b.handleErr(err)
  1151  		}
  1152  	}
  1153  
  1154  	pkSch := sql.NewPrimaryKeySchema(schema, getPkOrdinals(tableSpec)...)
  1155  	return pkSch, tableCollation, tableComment
  1156  }
  1157  
  1158  // jsonTableSpecToSchemaHelper appends the columns of a parsed JSONTableSpec to sch and returns the result
  1159  func (b *Builder) jsonTableSpecToSchemaHelper(jsonTableSpec *ast.JSONTableSpec, sch sql.Schema) sql.Schema {
  1160  	for _, cd := range jsonTableSpec.Columns {
  1161  		if cd.Spec != nil {
  1162  			sch = b.jsonTableSpecToSchemaHelper(cd.Spec, sch)
  1163  			continue
  1164  		}
  1165  		typ, err := types.ColumnTypeToType(&cd.Type)
  1166  		if err != nil {
  1167  			b.handleErr(err)
  1168  		}
  1169  		col := &sql.Column{
  1170  			Type:          typ,
  1171  			Name:          cd.Name.String(),
  1172  			AutoIncrement: bool(cd.Type.Autoincrement),
  1173  		}
  1174  		sch = append(sch, col)
  1175  	}
  1176  	return sch
  1177  }
  1178  
  1179  // jsonTableSpecToSchema creates a sql.Schema from a parsed JSONTableSpec
  1180  func (b *Builder) jsonTableSpecToSchema(tableSpec *ast.JSONTableSpec) sql.Schema {
  1181  	return b.jsonTableSpecToSchemaHelper(tableSpec, nil)
  1182  }
  1185  
  1186  // These constants aren't exported from vitess for some reason. This block could be removed if they were exported.
  1187  const (
  1188  	colKeyNone ast.ColumnKeyOption = iota
  1189  	colKeyPrimary
  1190  	colKeySpatialKey
  1191  	colKeyUnique
  1192  	colKeyUniqueKey
  1193  	colKey
  1194  	colKeyFulltextKey
  1195  )
  1196  
  1197  func getPkOrdinals(ts *ast.TableSpec) []int {
  1198  	for _, idxDef := range ts.Indexes {
  1199  		if idxDef.Info.Primary {
  1200  
  1201  			pkOrdinals := make([]int, 0)
  1202  			colIdx := make(map[string]int)
  1203  			for i := 0; i < len(ts.Columns); i++ {
  1204  				colIdx[ts.Columns[i].Name.Lowered()] = i
  1205  			}
  1206  
  1207  			for _, i := range idxDef.Columns {
  1208  				pkOrdinals = append(pkOrdinals, colIdx[i.Column.Lowered()])
  1209  			}
  1210  
  1211  			return pkOrdinals
  1212  		}
  1213  	}
  1214  
  1215  	// no primary key expression, check for inline PK column
  1216  	for i, col := range ts.Columns {
  1217  		if col.Type.KeyOpt == colKeyPrimary {
  1218  			return []int{i}
  1219  		}
  1220  	}
  1221  
  1222  	return []int{}
  1223  }
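
        // Editor's sketch (not part of the original source): for
        //
        //	CREATE TABLE t (a int, b int, c int, PRIMARY KEY (b, a));
        //
        // getPkOrdinals returns []int{1, 0} (key column order, not declaration order); an inline
        // `a int primary key` yields []int{0}, and a keyless table yields an empty slice.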
  1224  
  1225  // columnDefinitionToColumn returns the sql.Column for the column definition given, as part of a create table
  1226  // statement. Defaults and generated expressions must be handled separately.
  1227  func (b *Builder) columnDefinitionToColumn(inScope *scope, cd *ast.ColumnDefinition, indexes []*ast.IndexDefinition) *sql.Column {
  1228  	internalTyp, err := types.ColumnTypeToType(&cd.Type)
  1229  	if err != nil {
  1230  		b.handleErr(err)
  1231  	}
  1232  
  1233  	// Primary key info can either be specified in the column's type info (for in-line declarations), or in a slice of
  1234  	// indexes attached to the table def. We have to check both places to find if a column is part of the primary key
  1235  	isPkey := cd.Type.KeyOpt == colKeyPrimary
  1236  
  1237  	if !isPkey {
  1238  	OuterLoop:
  1239  		for _, index := range indexes {
  1240  			if index.Info.Primary {
  1241  				for _, indexCol := range index.Columns {
  1242  					if indexCol.Column.Equal(cd.Name) {
  1243  						isPkey = true
  1244  						break OuterLoop
  1245  					}
  1246  				}
  1247  			}
  1248  		}
  1249  	}
  1250  
  1251  	var comment string
  1252  	if cd.Type.Comment != nil && cd.Type.Comment.Type == ast.StrVal {
  1253  		comment = string(cd.Type.Comment.Val)
  1254  	}
  1255  
  1256  	nullable := !isPkey && !bool(cd.Type.NotNull)
  1257  	extra := ""
  1258  
  1259  	if cd.Type.Autoincrement {
  1260  		extra = "auto_increment"
  1261  	}
  1262  
  1263  	if cd.Type.SRID != nil {
  1264  		sridVal, err := strconv.ParseInt(string(cd.Type.SRID.Val), 10, 32)
  1265  		if err != nil {
  1266  			b.handleErr(err)
  1267  		}
  1268  
  1269  		if err = types.ValidateSRID(int(sridVal), ""); err != nil {
  1270  			b.handleErr(err)
  1271  		}
  1272  		if s, ok := internalTyp.(sql.SpatialColumnType); ok {
  1273  			internalTyp = s.SetSRID(uint32(sridVal))
  1274  		} else {
  1275  			b.handleErr(sql.ErrInvalidType.New(fmt.Sprintf("cannot define SRID for %s", internalTyp)))
  1276  		}
  1277  	}
  1278  
  1279  	return &sql.Column{
  1280  		Name:          cd.Name.String(),
  1281  		Type:          internalTyp,
  1282  		AutoIncrement: bool(cd.Type.Autoincrement),
  1283  		Nullable:      nullable,
  1284  		PrimaryKey:    isPkey,
  1285  		Comment:       comment,
  1286  		Extra:         extra,
  1287  	}
  1288  }
  1289  
  1290  func (b *Builder) modifySchemaTarget(inScope *scope, n sql.SchemaTarget, rt *plan.ResolvedTable) sql.Node {
  1291  	targSchema := b.resolveSchemaDefaults(inScope, rt.Schema())
  1292  	ret, err := n.WithTargetSchema(targSchema)
  1293  	if err != nil {
  1294  		b.handleErr(err)
  1295  	}
  1296  	return ret
  1297  }
  1298  
  1299  func (b *Builder) resolveSchemaDefaults(inScope *scope, schema sql.Schema) sql.Schema {
  1300  	if len(schema) == 0 {
  1301  		return nil
  1302  	}
  1303  	if len(inScope.cols) < len(schema) {
  1304  		// alter statements only add definitions for modified columns
  1305  		// backfill rest of columns
  1306  		resolveScope := inScope.replace()
  1307  		for _, col := range schema {
  1308  			resolveScope.newColumn(scopeColumn{
  1309  				db:       col.DatabaseSource,
  1310  				table:    strings.ToLower(col.Source),
  1311  				col:      strings.ToLower(col.Name),
  1312  				typ:      col.Type,
  1313  				nullable: col.Nullable,
  1314  			})
  1315  		}
  1316  		inScope = resolveScope
  1317  	}
  1318  
  1319  	newSch := schema.Copy()
  1320  	for _, part := range partitionTableColumns(newSch) {
  1321  		start := part[0]
  1322  		end := part[1]
  1323  		subScope := inScope.replace()
  1324  		for i := start; i < end; i++ {
  1325  			subScope.addColumn(inScope.cols[i])
  1326  		}
  1327  		for _, col := range newSch[start:end] {
  1328  			col.Default = b.resolveColumnDefaultExpression(subScope, col, col.Default)
  1329  			col.Generated = b.resolveColumnDefaultExpression(subScope, col, col.Generated)
  1330  			col.OnUpdate = b.resolveColumnDefaultExpression(subScope, col, col.OnUpdate)
  1331  		}
  1332  	}
  1333  	return newSch
  1334  }
  1335  
  1336  // partitionTableColumns splits a sql.Schema into a list
  1337  // of [2]int{start, end} ranges, one per contiguous table source,
  1338  // that together partition the schema's columns.
  1339  func partitionTableColumns(sch sql.Schema) [][2]int {
  1340  	var ret [][2]int
  1341  	i, prevI := 1, 0
  1343  	for i < len(sch) {
  1344  		if strings.EqualFold(sch[i-1].Source, sch[i].Source) &&
  1345  			strings.EqualFold(sch[i-1].DatabaseSource, sch[i].DatabaseSource) {
  1346  			i++
  1347  			continue
  1348  		}
  1349  		ret = append(ret, [2]int{prevI, i})
  1350  		prevI = i
  1351  		i++
  1352  	}
  1353  	ret = append(ret, [2]int{prevI, i})
  1354  	return ret
  1355  }
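
        // Editor's sketch (not part of the original source): for a schema whose columns come from
        // table sources (t1, t1, t2), partitionTableColumns returns [][2]int{{0, 2}, {2, 3}}, i.e.
        // half-open [start, end) column ranges grouped by contiguous (Source, DatabaseSource) pairs.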
  1356  
  1357  func (b *Builder) resolveColumnDefaultExpression(inScope *scope, columnDef *sql.Column, colDefault *sql.ColumnDefaultValue) *sql.ColumnDefaultValue {
  1358  	if colDefault == nil || colDefault.Expr == nil {
  1359  		return colDefault
  1360  	}
  1361  
  1362  	def, ok := colDefault.Expr.(*sql.UnresolvedColumnDefault)
  1363  	if !ok {
  1364  		// no resolution work to be done, return the original value
  1365  		return colDefault
  1366  	}
  1367  
  1368  	// Empty string is a special case, it means the default value is the empty string
  1369  	// TODO: why isn't this serialized as ''
  1370  	if def.String() == "" {
  1371  		return b.convertDefaultExpression(inScope, &ast.SQLVal{Val: []byte{}, Type: ast.StrVal}, columnDef.Type, columnDef.Nullable)
  1372  	}
  1373  
  1374  	parsed, err := ast.Parse(fmt.Sprintf("SELECT %s", def))
  1375  	if err != nil {
  1376  		err := fmt.Errorf("%w: %s", sql.ErrInvalidColumnDefaultValue.New(def), err)
  1377  		b.handleErr(err)
  1378  	}
  1379  
  1380  	selectStmt, ok := parsed.(*ast.Select)
  1381  	if !ok || len(selectStmt.SelectExprs) != 1 {
  1382  		err := sql.ErrInvalidColumnDefaultValue.New(def)
  1383  		b.handleErr(err)
  1384  	}
  1385  
  1386  	expr := selectStmt.SelectExprs[0]
  1387  	ae, ok := expr.(*ast.AliasedExpr)
  1388  	if !ok {
  1389  		err := sql.ErrInvalidColumnDefaultValue.New(def)
  1390  		b.handleErr(err)
  1391  	}
  1392  
  1393  	return b.convertDefaultExpression(inScope, ae.Expr, columnDef.Type, columnDef.Nullable)
  1394  }
  1395  
  1396  func (b *Builder) convertDefaultExpression(inScope *scope, defaultExpr ast.Expr, typ sql.Type, nullable bool) *sql.ColumnDefaultValue {
  1397  	if defaultExpr == nil {
  1398  		return nil
  1399  	}
  1400  	resExpr := b.buildScalar(inScope, defaultExpr)
  1401  
  1402  	// Function expressions must be enclosed in parentheses (except for current_timestamp() and now())
  1403  	_, isParenthesized := defaultExpr.(*ast.ParenExpr)
  1404  	isLiteral := !isParenthesized
  1405  
  1406  	// A literal will never have children, thus we can also check for that.
  1407  	if unaryExpr, is := defaultExpr.(*ast.UnaryExpr); is {
  1408  		if _, lit := unaryExpr.Expr.(*ast.SQLVal); lit {
  1409  			isLiteral = true
  1410  		}
  1411  	} else if !isParenthesized {
  1412  		if _, ok := resExpr.(sql.FunctionExpression); ok {
  1413  			switch resExpr.(type) {
  1414  			case *function.Now:
  1415  				// Datetime and Timestamp columns allow now and current_timestamp to not be enclosed in parens,
  1416  				// but they still need to be treated as function expressions
  1417  				isLiteral = false
  1418  			default:
  1419  				// All other functions must *always* be enclosed in parens
  1420  				err := sql.ErrSyntaxError.New("column default function expressions must be enclosed in parentheses")
  1421  				b.handleErr(err)
  1422  			}
  1423  		}
  1424  	}
  1425  
  1426  	// TODO: fix the vitess parser so that it parses negative numbers as numbers and not negation of an expression
  1427  	if unaryMinusExpr, ok := resExpr.(*expression.UnaryMinus); ok {
  1428  		if literalExpr, ok := unaryMinusExpr.Child.(*expression.Literal); ok {
  1429  			switch val := literalExpr.Value().(type) {
  1430  			case float32:
  1431  				resExpr = expression.NewLiteral(-val, types.Float32)
  1432  				isLiteral = true
  1433  			case float64:
  1434  				resExpr = expression.NewLiteral(-val, types.Float64)
  1435  				isLiteral = true
  1436  			}
  1437  		}
  1438  	}
  1439  
  1440  	return &sql.ColumnDefaultValue{
  1441  		Expr:          resExpr,
  1442  		OutType:       typ,
  1443  		Literal:       isLiteral,
  1444  		ReturnNil:     nullable,
  1445  		Parenthesized: isParenthesized,
  1446  	}
  1447  }
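
        // Editor's sketch (not part of the original source): `DEFAULT 5` yields a literal,
        // unparenthesized ColumnDefaultValue; `DEFAULT (rand())` yields a parenthesized, non-literal
        // one; `DEFAULT CURRENT_TIMESTAMP` resolves to function.Now and is allowed without parens;
        // any other bare function call (e.g. `DEFAULT rand()`) is a syntax error here. A negated
        // float literal such as `DEFAULT -1.5` is folded back into a plain negative literal.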
  1448  
  1449  func (b *Builder) buildDBDDL(inScope *scope, c *ast.DBDDL) (outScope *scope) {
  1450  	outScope = inScope.push()
  1451  	switch strings.ToLower(c.Action) {
  1452  	case ast.CreateStr:
  1453  		var charsetStr *string
  1454  		var collationStr *string
  1455  		for _, cc := range c.CharsetCollate {
  1456  			ccType := strings.ToLower(cc.Type)
  1457  			if ccType == "character set" {
  1458  				val := cc.Value
  1459  				charsetStr = &val
  1460  			} else if ccType == "collate" {
  1461  				val := cc.Value
  1462  				collationStr = &val
  1463  			} else if b.ctx != nil && b.ctx.Session != nil {
  1464  				b.ctx.Session.Warn(&sql.Warning{
  1465  					Level:   "Warning",
  1466  					Code:    mysql.ERNotSupportedYet,
  1467  					Message: "Setting CHARACTER SET, COLLATION and ENCRYPTION are not supported yet",
  1468  				})
  1469  			}
  1470  		}
  1471  		collation, err := sql.ParseCollation(charsetStr, collationStr, false)
  1472  		if err != nil {
  1473  			b.handleErr(err)
  1474  		}
  1475  		createDb := plan.NewCreateDatabase(c.DBName, c.IfNotExists, collation)
  1476  		createDb.Catalog = b.cat
  1477  		outScope.node = createDb
  1478  	case ast.DropStr:
  1479  		dropDb := plan.NewDropDatabase(c.DBName, c.IfExists)
  1480  		dropDb.Catalog = b.cat
  1481  		outScope.node = dropDb
  1482  	case ast.AlterStr:
  1483  		if len(c.CharsetCollate) == 0 {
  1484  			if len(c.DBName) > 0 {
  1485  				err := sql.ErrSyntaxError.New(fmt.Sprintf("alter database %s", c.DBName))
  1486  				b.handleErr(err)
  1487  			} else {
  1488  				err := sql.ErrSyntaxError.New("alter database")
  1489  				b.handleErr(err)
  1490  			}
  1491  		}
  1492  
  1493  		var charsetStr *string
  1494  		var collationStr *string
  1495  		for _, cc := range c.CharsetCollate {
  1496  			ccType := strings.ToLower(cc.Type)
  1497  			if ccType == "character set" {
  1498  				val := cc.Value
  1499  				charsetStr = &val
  1500  			} else if ccType == "collate" {
  1501  				val := cc.Value
  1502  				collationStr = &val
  1503  			}
  1504  		}
  1505  		collation, err := sql.ParseCollation(charsetStr, collationStr, false)
  1506  		if err != nil {
  1507  			b.handleErr(err)
  1508  		}
  1509  		alterDb := plan.NewAlterDatabase(c.DBName, collation)
  1510  		alterDb.Catalog = b.cat
  1511  		outScope.node = alterDb
  1512  	default:
  1513  		err := sql.ErrUnsupportedSyntax.New(ast.String(c))
  1514  		b.handleErr(err)
  1515  	}
  1516  	return outScope
  1517  }
  1518  
  1519  // ExtendedTypeTag is primarily used by ParseColumnTypeString when parsing strings representing extended types
  1520  const ExtendedTypeTag = "extended_"
  1521  
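        // ParseColumnTypeString parses a single column type string, either an "extended_"-prefixed
        // serialized type or a plain MySQL type such as "varchar(20)", into a sql.Type by embedding
        // it in a synthetic CREATE TABLE statement, applying any SRID attribute for spatial types.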
  1522  func ParseColumnTypeString(columnType string) (sql.Type, error) {
  1523  	if strings.HasPrefix(columnType, ExtendedTypeTag) {
  1524  		columnType = columnType[len(ExtendedTypeTag):]
  1525  		// If the pipe character "|" is present, then we ignore all information after it (including the pipe), as it
  1526  		// represents a comment
  1527  		if pipeIdx := strings.Index(columnType, "|"); pipeIdx != -1 {
  1528  			columnType = columnType[:pipeIdx]
  1529  		}
  1530  		c, err := types.DeserializeTypeFromString(columnType)
  1531  		if err != nil {
  1532  			return nil, err
  1533  		}
  1534  		return c, nil
  1535  	}
  1536  	parsed, err := ast.Parse(fmt.Sprintf("create table t(a %s)", columnType))
  1537  	if err != nil {
  1538  		return nil, err
  1539  	}
  1540  	ddl, ok := parsed.(*ast.DDL)
  1541  	if !ok {
  1542  		return nil, fmt.Errorf("failed to parse type info for column: %s", columnType)
  1543  	}
  1544  	parsedTyp := ddl.TableSpec.Columns[0].Type
  1545  	typ, err := types.ColumnTypeToType(&parsedTyp)
  1546  	if err != nil {
  1547  		return nil, err
  1548  	}
  1549  	if parsedTyp.SRID != nil {
  1550  		sridVal, err := strconv.ParseInt(string(parsedTyp.SRID.Val), 10, 32)
  1551  		if err != nil {
  1552  			return nil, err
  1553  		}
  1554  
  1555  		if err = types.ValidateSRID(int(sridVal), ""); err != nil {
  1556  			return nil, err
  1557  		}
  1558  		if s, ok := typ.(sql.SpatialColumnType); ok {
  1559  			typ = s.SetSRID(uint32(sridVal))
  1560  		} else {
  1561  			return nil, sql.ErrInvalidType.New(fmt.Sprintf("cannot define SRID for %s", typ))
  1562  		}
  1563  	}
  1564  	return typ, nil
  1565  }
  1566  
  1567  var _ sql.Database = dummyDb{}
  1568  
  1569  type dummyDb struct {
  1570  	name string
  1571  }
  1572  
  1573  func (d dummyDb) Name() string                 { return d.name }
  1574  func (d dummyDb) Tables() map[string]sql.Table { return nil }
  1575  func (d dummyDb) GetTableInsensitive(ctx *sql.Context, tblName string) (sql.Table, bool, error) {
  1576  	return nil, false, nil
  1577  }
  1578  func (d dummyDb) GetTableNames(ctx *sql.Context) ([]string, error) {
  1579  	return nil, nil
  1580  }