github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/dml.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package planbuilder
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	ast "github.com/dolthub/vitess/go/vt/sqlparser"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/plan"
    26  	"github.com/dolthub/go-mysql-server/sql/transform"
    27  )
    28  
// OnDupValuesPrefix is the placeholder table name that buildInsert assigns to
// the extra copy of each destination column it adds to the ON DUPLICATE KEY
// UPDATE scope when the source columns cannot be paired one-to-one with the
// destination columns.
const OnDupValuesPrefix = "__new_ins"
    30  
    31  func (b *Builder) buildInsert(inScope *scope, i *ast.Insert) (outScope *scope) {
    32  	if i.With != nil {
    33  		inScope = b.buildWith(inScope, i.With)
    34  	}
    35  	dbName := i.Table.Qualifier.String()
    36  	tableName := i.Table.Name.String()
    37  	destScope, ok := b.buildResolvedTable(inScope, dbName, tableName, nil)
    38  	if !ok {
    39  		b.handleErr(sql.ErrTableNotFound.New(tableName))
    40  	}
    41  	var db sql.Database
    42  	var rt *plan.ResolvedTable
    43  	switch n := destScope.node.(type) {
    44  	case *plan.ResolvedTable:
    45  		rt = n
    46  		db = rt.SqlDatabase
    47  	case *plan.UnresolvedTable:
    48  		db = n.Database()
    49  	default:
    50  		b.handleErr(fmt.Errorf("expected insert destination to be resolved or unresolved table"))
    51  	}
    52  	if rt == nil {
    53  		if b.TriggerCtx().Active && !b.TriggerCtx().Call {
    54  			b.TriggerCtx().UnresolvedTables = append(b.TriggerCtx().UnresolvedTables, tableName)
    55  		} else {
    56  			err := fmt.Errorf("expected resolved table: %s", tableName)
    57  			b.handleErr(err)
    58  		}
    59  	}
    60  	isReplace := i.Action == ast.ReplaceStr
    61  
    62  	var columns []string
    63  	{
    64  		columns = columnsToStrings(i.Columns)
    65  		// If no column names were specified in the query, go ahead and fill
    66  		// them all in now that the destination is resolved.
    67  		// TODO: setting the plan field directly is not great
    68  		if len(columns) == 0 && len(destScope.cols) > 0 && rt != nil {
    69  			schema := rt.Schema()
    70  			columns = make([]string, len(schema))
    71  			for i, col := range schema {
    72  				// Tables with any generated column must always supply a column list, so this is always an error
    73  				if col.Generated != nil {
    74  					b.handleErr(sql.ErrGeneratedColumnValue.New(col.Name, rt.Name()))
    75  				}
    76  				columns[i] = col.Name
    77  			}
    78  		}
    79  	}
    80  	sch := destScope.node.Schema()
    81  	if rt != nil {
    82  		sch = b.resolveSchemaDefaults(destScope, rt.Schema())
    83  	}
    84  	srcScope := b.insertRowsToNode(inScope, i.Rows, columns, tableName, sch)
    85  
    86  	// TODO: on duplicate expressions need to reference both VALUES and
    87  	//  derived columns equally in ON DUPLICATE UPDATE expressions.
    88  	combinedScope := inScope.replace()
    89  	for i, c := range destScope.cols {
    90  		combinedScope.newColumn(c)
    91  		if len(srcScope.cols) == len(destScope.cols) {
    92  			combinedScope.newColumn(srcScope.cols[i])
    93  		} else {
    94  			// check for VALUES refs
    95  			c.table = OnDupValuesPrefix
    96  			combinedScope.newColumn(c)
    97  		}
    98  	}
    99  	onDupExprs := b.buildOnDupUpdateExprs(combinedScope, destScope, ast.AssignmentExprs(i.OnDup))
   100  
   101  	ignore := false
   102  	// TODO: make this a bool in vitess
   103  	if strings.Contains(strings.ToLower(i.Ignore), "ignore") {
   104  		ignore = true
   105  	}
   106  
   107  	dest := destScope.node
   108  
   109  	ins := plan.NewInsertInto(db, plan.NewInsertDestination(sch, dest), srcScope.node, isReplace, columns, onDupExprs, ignore)
   110  
   111  	b.validateInsert(ins)
   112  
   113  	outScope = destScope
   114  	outScope.node = ins
   115  	if rt != nil {
   116  		checks := b.loadChecksFromTable(destScope, rt.Table)
   117  		outScope.node = ins.WithChecks(checks)
   118  	}
   119  
   120  	return
   121  }
   122  
   123  func (b *Builder) insertRowsToNode(inScope *scope, ir ast.InsertRows, columnNames []string, tableName string, destSchema sql.Schema) (outScope *scope) {
   124  	switch v := ir.(type) {
   125  	case ast.SelectStatement:
   126  		return b.buildSelectStmt(inScope, v)
   127  	case ast.Values:
   128  		outScope = b.buildInsertValues(inScope, v, columnNames, tableName, destSchema)
   129  	default:
   130  		err := sql.ErrUnsupportedSyntax.New(ast.String(ir))
   131  		b.handleErr(err)
   132  	}
   133  	return
   134  }
   135  
   136  func (b *Builder) buildInsertValues(inScope *scope, v ast.Values, columnNames []string, tableName string, destSchema sql.Schema) (outScope *scope) {
   137  	columnDefaultValues := make([]*sql.ColumnDefaultValue, len(columnNames))
   138  
   139  	for i, columnName := range columnNames {
   140  		index := destSchema.IndexOfColName(columnName)
   141  		if index == -1 {
   142  			if !b.TriggerCtx().Call && len(b.TriggerCtx().UnresolvedTables) > 0 {
   143  				continue
   144  			}
   145  			err := sql.ErrUnknownColumn.New(columnName, tableName)
   146  			b.handleErr(err)
   147  		}
   148  
   149  		columnDefaultValues[i] = destSchema[index].Default
   150  		if columnDefaultValues[i] == nil && destSchema[index].Generated != nil {
   151  			columnDefaultValues[i] = destSchema[index].Generated
   152  		}
   153  	}
   154  
   155  	exprTuples := make([][]sql.Expression, len(v))
   156  	for i, vt := range v {
   157  		// noExprs is an edge case where we fill VALUES with nil expressions
   158  		noExprs := len(vt) == 0
   159  		// triggerUnknownTable is an edge case where we ignored an unresolved
   160  		// table error and do not have a schema for resolving defaults
   161  		triggerUnknownTable := (len(columnNames) == 0 && len(vt) > 0) && (len(b.TriggerCtx().UnresolvedTables) > 0)
   162  
   163  		if len(vt) != len(columnNames) && !noExprs && !triggerUnknownTable {
   164  			err := sql.ErrInsertIntoMismatchValueCount.New()
   165  			b.handleErr(err)
   166  		}
   167  		exprs := make([]sql.Expression, len(columnNames))
   168  		exprTuples[i] = exprs
   169  		for j := range columnNames {
   170  			if noExprs || triggerUnknownTable {
   171  				exprs[j] = expression.WrapExpression(columnDefaultValues[j])
   172  				continue
   173  			}
   174  			e := vt[j]
   175  			switch e := e.(type) {
   176  			case *ast.Default:
   177  				exprs[j] = expression.WrapExpression(columnDefaultValues[j])
   178  				// explicit DEFAULT values need their column indexes assigned early, since we analyze the insert values in
   179  				// isolation (no access to the destination schema)
   180  				exprs[j] = assignColumnIndexes(exprs[j], reorderSchema(columnNames, destSchema))
   181  			case *ast.SQLVal:
   182  				// In the case of an unknown bindvar, give it a target type of the column it's targeting.
   183  				// We only do this for simple bindvars in tuples, not expressions that contain bindvars.
   184  				if b.shouldAssignBindvarType(e) {
   185  					name := strings.TrimPrefix(string(e.Val), ":")
   186  					bindVar := expression.NewBindVar(name)
   187  					bindVar.Typ = reorderSchema(columnNames, destSchema)[j].Type
   188  					exprs[j] = bindVar
   189  				} else {
   190  					exprs[j] = b.buildScalar(inScope, e)
   191  				}
   192  			default:
   193  				exprs[j] = b.buildScalar(inScope, e)
   194  			}
   195  		}
   196  	}
   197  
   198  	outScope = inScope.push()
   199  	outScope.node = plan.NewValues(exprTuples)
   200  	return
   201  }
   202  
   203  func (b *Builder) shouldAssignBindvarType(e *ast.SQLVal) bool {
   204  	return e.Type == ast.ValArg && (b.bindCtx == nil || b.bindCtx.resolveOnly)
   205  }
   206  
   207  // reorderSchema returns the schemas columns in the order specified by names
   208  func reorderSchema(names []string, schema sql.Schema) sql.Schema {
   209  	newSch := make(sql.Schema, len(names))
   210  	for i, name := range names {
   211  		newSch[i] = schema[schema.IndexOfColName(name)]
   212  	}
   213  	return newSch
   214  }
   215  
   216  func (b *Builder) buildValues(inScope *scope, v ast.Values) (outScope *scope) {
   217  	// TODO add literals to outScope?
   218  	exprTuples := make([][]sql.Expression, len(v))
   219  	for i, vt := range v {
   220  		exprs := make([]sql.Expression, len(vt))
   221  		exprTuples[i] = exprs
   222  		for j, e := range vt {
   223  			exprs[j] = b.buildScalar(inScope, e)
   224  		}
   225  	}
   226  
   227  	outScope = inScope.push()
   228  	outScope.node = plan.NewValues(exprTuples)
   229  	return
   230  }
   231  
   232  func (b *Builder) assignmentExprsToExpressions(inScope *scope, e ast.AssignmentExprs) []sql.Expression {
   233  	updateExprs := make([]sql.Expression, len(e))
   234  	var startAggCnt int
   235  	if inScope.groupBy != nil {
   236  		startAggCnt = len(inScope.groupBy.aggs)
   237  	}
   238  	var startWinCnt int
   239  	if inScope.windowFuncs != nil {
   240  		startWinCnt = len(inScope.windowFuncs)
   241  	}
   242  
   243  	tableSch := b.resolveSchemaDefaults(inScope, inScope.node.Schema())
   244  
   245  	for i, updateExpr := range e {
   246  		colName := b.buildScalar(inScope, updateExpr.Name)
   247  
   248  		innerExpr := b.buildScalar(inScope, updateExpr.Expr)
   249  		if gf, ok := colName.(*expression.GetField); ok {
   250  			colIdx := tableSch.IndexOfColName(gf.Name())
   251  			// TODO: during trigger parsing the table in the node is unresolved, so we need this additional bounds check
   252  			//  This means that trigger execution will be able to update generated columns
   253  			// Prevent update of generated columns
   254  			if colIdx >= 0 && tableSch[colIdx].Generated != nil {
   255  				err := sql.ErrGeneratedColumnValue.New(tableSch[colIdx].Name, inScope.node.(sql.NameableNode).Name())
   256  				b.handleErr(err)
   257  			}
   258  
   259  			// Replace default with column default from resolved schema
   260  			if _, ok := updateExpr.Expr.(*ast.Default); ok {
   261  				if colIdx >= 0 {
   262  					innerExpr = expression.WrapExpression(tableSch[colIdx].Default)
   263  				}
   264  			}
   265  		}
   266  
   267  		// In the case of an unknown bindvar, give it a target type of the column it's targeting.
   268  		// We only do this for simple bindvars in tuples, not expressions that contain bindvars.
   269  		if innerSqlVal, ok := updateExpr.Expr.(*ast.SQLVal); ok && b.shouldAssignBindvarType(innerSqlVal) {
   270  			if typ, ok := hasColumnType(colName); ok {
   271  				rightBindVar := innerExpr.(*expression.BindVar)
   272  				rightBindVar.Typ = typ
   273  				innerExpr = rightBindVar
   274  			}
   275  		}
   276  
   277  		updateExprs[i] = expression.NewSetField(colName, innerExpr)
   278  		if inScope.groupBy != nil {
   279  			if len(inScope.groupBy.aggs) > startAggCnt {
   280  				err := sql.ErrAggregationUnsupported.New(updateExprs[i])
   281  				b.handleErr(err)
   282  			}
   283  		}
   284  		if inScope.windowFuncs != nil {
   285  			if len(inScope.windowFuncs) > startWinCnt {
   286  				err := sql.ErrWindowUnsupported.New(updateExprs[i])
   287  				b.handleErr(err)
   288  			}
   289  		}
   290  	}
   291  
   292  	// We need additional update expressions for any generated columns and on update expressions, since they won't be part of the update
   293  	// expressions, but their value in the row must be updated before being passed to the integrator for storage.
   294  	if len(tableSch) > 0 {
   295  		tabId := inScope.tables[strings.ToLower(tableSch[0].Source)]
   296  		for i, col := range tableSch {
   297  			if col.Generated != nil {
   298  				colGf := expression.NewGetFieldWithTable(i+1, int(tabId), col.Type, col.DatabaseSource, col.Source, col.Name, col.Nullable)
   299  				generated := b.resolveColumnDefaultExpression(inScope, col, col.Generated)
   300  				updateExprs = append(updateExprs, expression.NewSetField(colGf, assignColumnIndexes(generated, tableSch)))
   301  			}
   302  			if col.OnUpdate != nil {
   303  				// don't add if column is already being updated
   304  				if !isColumnUpdated(col, updateExprs) {
   305  					colGf := expression.NewGetFieldWithTable(i+1, int(tabId), col.Type, col.DatabaseSource, col.Source, col.Name, col.Nullable)
   306  					onUpdate := b.resolveColumnDefaultExpression(inScope, col, col.OnUpdate)
   307  					updateExprs = append(updateExprs, expression.NewSetField(colGf, assignColumnIndexes(onUpdate, tableSch)))
   308  				}
   309  			}
   310  		}
   311  	}
   312  
   313  	return updateExprs
   314  }
   315  
   316  func isColumnUpdated(col *sql.Column, updateExprs []sql.Expression) bool {
   317  	for _, expr := range updateExprs {
   318  		sf, ok := expr.(*expression.SetField)
   319  		if !ok {
   320  			continue
   321  		}
   322  		gf, ok := sf.LeftChild.(*expression.GetField)
   323  		if !ok {
   324  			continue
   325  		}
   326  		if strings.EqualFold(gf.Name(), col.Name) {
   327  			return true
   328  		}
   329  	}
   330  	return false
   331  }
   332  
   333  func (b *Builder) buildOnDupUpdateExprs(combinedScope, destScope *scope, e ast.AssignmentExprs) []sql.Expression {
   334  	b.insertActive = true
   335  	defer func() {
   336  		b.insertActive = false
   337  	}()
   338  	res := make([]sql.Expression, len(e))
   339  	// todo(max): prevent aggregations in separate semantic walk step
   340  	var startAggCnt int
   341  	if combinedScope.groupBy != nil {
   342  		startAggCnt = len(combinedScope.groupBy.aggs)
   343  	}
   344  	var startWinCnt int
   345  	if combinedScope.windowFuncs != nil {
   346  		startWinCnt = len(combinedScope.windowFuncs)
   347  	}
   348  	for i, updateExpr := range e {
   349  		colName := b.buildOnDupLeft(destScope, updateExpr.Name)
   350  		innerExpr := b.buildScalar(combinedScope, updateExpr.Expr)
   351  
   352  		res[i] = expression.NewSetField(colName, innerExpr)
   353  		if combinedScope.groupBy != nil {
   354  			if len(combinedScope.groupBy.aggs) > startAggCnt {
   355  				err := sql.ErrAggregationUnsupported.New(res[i])
   356  				b.handleErr(err)
   357  			}
   358  		}
   359  		if combinedScope.windowFuncs != nil {
   360  			if len(combinedScope.windowFuncs) > startWinCnt {
   361  				err := sql.ErrWindowUnsupported.New(res[i])
   362  				b.handleErr(err)
   363  			}
   364  		}
   365  	}
   366  	return res
   367  }
   368  
   369  func (b *Builder) buildOnDupLeft(inScope *scope, e ast.Expr) sql.Expression {
   370  	// expect col reference only
   371  	switch e := e.(type) {
   372  	case *ast.ColName:
   373  		dbName := strings.ToLower(e.Qualifier.Qualifier.String())
   374  		tblName := strings.ToLower(e.Qualifier.Name.String())
   375  		colName := strings.ToLower(e.Name.String())
   376  		c, ok := inScope.resolveColumn(dbName, tblName, colName, true, false)
   377  		if !ok {
   378  			if tblName != "" && !inScope.hasTable(tblName) {
   379  				b.handleErr(sql.ErrTableNotFound.New(tblName))
   380  			} else if tblName != "" {
   381  				b.handleErr(sql.ErrTableColumnNotFound.New(tblName, colName))
   382  			}
   383  			b.handleErr(sql.ErrColumnNotFound.New(e))
   384  		}
   385  		return c.scalarGf()
   386  	default:
   387  		err := fmt.Errorf("invalid update target; expected column reference, found: %T", e)
   388  		b.handleErr(err)
   389  	}
   390  	return nil
   391  }
   392  
   393  func (b *Builder) buildDelete(inScope *scope, d *ast.Delete) (outScope *scope) {
   394  	outScope = b.buildFrom(inScope, d.TableExprs)
   395  	b.buildWhere(outScope, d.Where)
   396  	orderByScope := b.analyzeOrderBy(outScope, outScope, d.OrderBy)
   397  	b.buildOrderBy(outScope, orderByScope)
   398  	offset := b.buildOffset(outScope, d.Limit)
   399  	if offset != nil {
   400  		outScope.node = plan.NewOffset(offset, outScope.node)
   401  	}
   402  	limit := b.buildLimit(outScope, d.Limit)
   403  	if limit != nil {
   404  		outScope.node = plan.NewLimit(limit, outScope.node)
   405  	}
   406  
   407  	var targets []sql.Node
   408  	if len(d.Targets) > 0 {
   409  		targets = make([]sql.Node, len(d.Targets))
   410  		for i, tableName := range d.Targets {
   411  			tabName := tableName.Name.String()
   412  			dbName := tableName.Qualifier.String()
   413  			if dbName == "" {
   414  				dbName = b.ctx.GetCurrentDatabase()
   415  			}
   416  			var target sql.Node
   417  			if _, ok := outScope.tables[tabName]; ok {
   418  				transform.InspectUp(outScope.node, func(n sql.Node) bool {
   419  					switch n := n.(type) {
   420  					case sql.NameableNode:
   421  						if strings.EqualFold(n.Name(), tabName) {
   422  							target = n
   423  							return true
   424  						}
   425  					default:
   426  					}
   427  					return false
   428  				})
   429  			} else {
   430  				tableScope, ok := b.buildResolvedTable(inScope, dbName, tabName, nil)
   431  				if !ok {
   432  					b.handleErr(sql.ErrTableNotFound.New(tabName))
   433  				}
   434  				target = tableScope.node
   435  			}
   436  			targets[i] = target
   437  		}
   438  	}
   439  
   440  	del := plan.NewDeleteFrom(outScope.node, targets)
   441  	outScope.node = del
   442  	return
   443  }
   444  
   445  func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
   446  	outScope = b.buildFrom(inScope, u.TableExprs)
   447  
   448  	// default expressions only resolve to target table
   449  	updateExprs := b.assignmentExprsToExpressions(outScope, u.Exprs)
   450  
   451  	b.buildWhere(outScope, u.Where)
   452  
   453  	orderByScope := b.analyzeOrderBy(outScope, b.newScope(), u.OrderBy)
   454  
   455  	b.buildOrderBy(outScope, orderByScope)
   456  	offset := b.buildOffset(outScope, u.Limit)
   457  	if offset != nil {
   458  		outScope.node = plan.NewOffset(offset, outScope.node)
   459  	}
   460  
   461  	limit := b.buildLimit(outScope, u.Limit)
   462  	if limit != nil {
   463  		outScope.node = plan.NewLimit(limit, outScope.node)
   464  	}
   465  
   466  	// TODO comments
   467  	// If the top level node can store comments and one was provided, store it.
   468  	//if cn, ok := node.(sql.CommentedNode); ok && len(u.Comments) > 0 {
   469  	//	node = cn.WithComment(string(u.Comments[0]))
   470  	//}
   471  
   472  	ignore := u.Ignore != ""
   473  	update := plan.NewUpdate(outScope.node, ignore, updateExprs)
   474  
   475  	var checks []*sql.CheckConstraint
   476  	if join, ok := outScope.node.(*plan.JoinNode); ok {
   477  		source := plan.NewUpdateSource(
   478  			join,
   479  			ignore,
   480  			updateExprs,
   481  		)
   482  		updaters, err := rowUpdatersByTable(b.ctx, source, join)
   483  		if err != nil {
   484  			b.handleErr(err)
   485  		}
   486  		updateJoin := plan.NewUpdateJoin(updaters, source)
   487  		update.Child = updateJoin
   488  		transform.Inspect(update, func(n sql.Node) bool {
   489  			// todo maybe this should be later stage
   490  			switch n := n.(type) {
   491  			case sql.NameableNode:
   492  				if _, ok := updaters[n.Name()]; ok {
   493  					rt := getResolvedTable(n)
   494  					tableScope := inScope.push()
   495  					for _, c := range rt.Schema() {
   496  						tableScope.addColumn(scopeColumn{
   497  							db:       rt.SqlDatabase.Name(),
   498  							table:    strings.ToLower(n.Name()),
   499  							tableId:  tableScope.tables[strings.ToLower(n.Name())],
   500  							col:      strings.ToLower(c.Name),
   501  							typ:      c.Type,
   502  							nullable: c.Nullable,
   503  						})
   504  					}
   505  					checks = append(checks, b.loadChecksFromTable(tableScope, rt.Table)...)
   506  				}
   507  			default:
   508  			}
   509  			return true
   510  		})
   511  	} else {
   512  		transform.Inspect(update, func(n sql.Node) bool {
   513  			// todo maybe this should be later stage
   514  			if rt, ok := n.(*plan.ResolvedTable); ok {
   515  				checks = append(checks, b.loadChecksFromTable(outScope, rt.Table)...)
   516  			}
   517  			return true
   518  		})
   519  	}
   520  	outScope.node = update.WithChecks(checks)
   521  	return
   522  }
   523  
   524  // rowUpdatersByTable maps a set of tables to their RowUpdater objects.
   525  func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) {
   526  	namesOfTableToBeUpdated := getTablesToBeUpdated(node)
   527  	resolvedTables := getTablesByName(ij)
   528  
   529  	rowUpdatersByTable := make(map[string]sql.RowUpdater)
   530  	for tableToBeUpdated, _ := range namesOfTableToBeUpdated {
   531  		resolvedTable, ok := resolvedTables[tableToBeUpdated]
   532  		if !ok {
   533  			return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
   534  		}
   535  
   536  		var table = resolvedTable.UnderlyingTable()
   537  
   538  		// If there is no UpdatableTable for a table being updated, error out
   539  		updatable, ok := table.(sql.UpdatableTable)
   540  		if !ok && updatable == nil {
   541  			return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
   542  		}
   543  
   544  		keyless := sql.IsKeyless(updatable.Schema())
   545  		if keyless {
   546  			return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN")
   547  		}
   548  
   549  		rowUpdatersByTable[tableToBeUpdated] = updatable.Updater(ctx)
   550  	}
   551  
   552  	return rowUpdatersByTable, nil
   553  }
   554  
   555  // getTablesByName takes a node and returns all found resolved tables in a map.
   556  func getTablesByName(node sql.Node) map[string]*plan.ResolvedTable {
   557  	ret := make(map[string]*plan.ResolvedTable)
   558  
   559  	transform.Inspect(node, func(node sql.Node) bool {
   560  		switch n := node.(type) {
   561  		case *plan.ResolvedTable:
   562  			ret[n.Table.Name()] = n
   563  		case *plan.IndexedTableAccess:
   564  			rt, ok := n.TableNode.(*plan.ResolvedTable)
   565  			if ok {
   566  				ret[rt.Name()] = rt
   567  			}
   568  		case *plan.TableAlias:
   569  			rt := getResolvedTable(n)
   570  			if rt != nil {
   571  				ret[n.Name()] = rt
   572  			}
   573  		default:
   574  		}
   575  		return true
   576  	})
   577  
   578  	return ret
   579  }
   580  
   581  // Finds first TableNode node that is a descendant of the node given
   582  func getResolvedTable(node sql.Node) *plan.ResolvedTable {
   583  	var table *plan.ResolvedTable
   584  	transform.Inspect(node, func(node sql.Node) bool {
   585  		// plan.Inspect will get called on all children of a node even if one of the children's calls returns false. We
   586  		// only want the first TableNode match.
   587  		if table != nil {
   588  			return false
   589  		}
   590  
   591  		switch n := node.(type) {
   592  		case *plan.ResolvedTable:
   593  			if !plan.IsDualTable(n) {
   594  				table = n
   595  				return false
   596  			}
   597  		case *plan.IndexedTableAccess:
   598  			rt, ok := n.TableNode.(*plan.ResolvedTable)
   599  			if ok {
   600  				table = rt
   601  				return false
   602  			}
   603  		}
   604  		return true
   605  	})
   606  	return table
   607  }
   608  
   609  // getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
   610  func getTablesToBeUpdated(node sql.Node) map[string]struct{} {
   611  	ret := make(map[string]struct{})
   612  
   613  	transform.InspectExpressions(node, func(e sql.Expression) bool {
   614  		switch e := e.(type) {
   615  		case *expression.SetField:
   616  			gf := e.LeftChild.(*expression.GetField)
   617  			ret[gf.Table()] = struct{}{}
   618  			return false
   619  		}
   620  
   621  		return true
   622  	})
   623  
   624  	return ret
   625  }
   626  
   627  func (b *Builder) buildInto(inScope *scope, into *ast.Into) {
   628  	if into.Dumpfile != "" {
   629  		inScope.node = plan.NewInto(inScope.node, nil, "", into.Dumpfile)
   630  		return
   631  	}
   632  
   633  	if into.Outfile != "" {
   634  		intoNode := plan.NewInto(inScope.node, nil, into.Outfile, "")
   635  
   636  		if into.Charset != "" {
   637  			// TODO: deal with charset; error for now
   638  			intoNode.Charset = into.Charset
   639  			b.handleErr(sql.ErrUnsupportedFeature.New("CHARSET in INTO OUTFILE"))
   640  		}
   641  
   642  		if into.Fields != nil {
   643  			if into.Fields.TerminatedBy != nil && len(into.Fields.TerminatedBy.Val) != 0 {
   644  				intoNode.FieldsTerminatedBy = string(into.Fields.TerminatedBy.Val)
   645  			}
   646  			if into.Fields.EnclosedBy != nil {
   647  				intoNode.FieldsEnclosedBy = string(into.Fields.EnclosedBy.Delim.Val)
   648  				if len(intoNode.FieldsEnclosedBy) > 1 {
   649  					b.handleErr(sql.ErrUnexpectedSeparator.New())
   650  				}
   651  				if into.Fields.EnclosedBy.Optionally {
   652  					intoNode.FieldsEnclosedByOpt = true
   653  				}
   654  			}
   655  			if into.Fields.EscapedBy != nil {
   656  				intoNode.FieldsEscapedBy = string(into.Fields.EscapedBy.Val)
   657  				if len(intoNode.FieldsEscapedBy) > 1 {
   658  					b.handleErr(sql.ErrUnexpectedSeparator.New())
   659  				}
   660  			}
   661  		}
   662  
   663  		if into.Lines != nil {
   664  			if into.Lines.StartingBy != nil {
   665  				intoNode.LinesStartingBy = string(into.Lines.StartingBy.Val)
   666  			}
   667  			if into.Lines.TerminatedBy != nil {
   668  				intoNode.LinesTerminatedBy = string(into.Lines.TerminatedBy.Val)
   669  			}
   670  		}
   671  
   672  		inScope.node = intoNode
   673  		return
   674  	}
   675  
   676  	vars := make([]sql.Expression, len(into.Variables))
   677  	for i, val := range into.Variables {
   678  		if strings.HasPrefix(val.String(), "@") {
   679  			vars[i] = expression.NewUserVar(strings.TrimPrefix(val.String(), "@"))
   680  		} else {
   681  			col, ok := inScope.proc.GetVar(val.String())
   682  			if !ok {
   683  				err := sql.ErrExternalProcedureMissingContextParam.New(val.String())
   684  				b.handleErr(err)
   685  			}
   686  			vars[i] = col.scalarGf()
   687  		}
   688  	}
   689  	inScope.node = plan.NewInto(inScope.node, vars, "", "")
   690  }
   691  
   692  func (b *Builder) loadChecksFromTable(inScope *scope, table sql.Table) []*sql.CheckConstraint {
   693  	var loadedChecks []*sql.CheckConstraint
   694  	if checkTable, ok := table.(sql.CheckTable); ok {
   695  		checks, err := checkTable.GetChecks(b.ctx)
   696  		if err != nil {
   697  			b.handleErr(err)
   698  		}
   699  		for _, ch := range checks {
   700  			constraint := b.buildCheckConstraint(inScope, &ch)
   701  			loadedChecks = append(loadedChecks, constraint)
   702  		}
   703  	}
   704  	return loadedChecks
   705  }
   706  
   707  func (b *Builder) buildCheckConstraint(inScope *scope, check *sql.CheckDefinition) *sql.CheckConstraint {
   708  	parseStr := fmt.Sprintf("select %s", check.CheckExpression)
   709  	parsed, err := ast.Parse(parseStr)
   710  	if err != nil {
   711  		b.handleErr(err)
   712  	}
   713  
   714  	selectStmt, ok := parsed.(*ast.Select)
   715  	if !ok || len(selectStmt.SelectExprs) != 1 {
   716  		err := sql.ErrInvalidCheckConstraint.New(check.CheckExpression)
   717  		b.handleErr(err)
   718  	}
   719  
   720  	expr := selectStmt.SelectExprs[0]
   721  	ae, ok := expr.(*ast.AliasedExpr)
   722  	if !ok {
   723  		err := sql.ErrInvalidCheckConstraint.New(check.CheckExpression)
   724  		b.handleErr(err)
   725  	}
   726  
   727  	c := b.buildScalar(inScope, ae.Expr)
   728  
   729  	return &sql.CheckConstraint{
   730  		Name:     check.Name,
   731  		Expr:     c,
   732  		Enforced: check.Enforced,
   733  	}
   734  }