github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/dml.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"fmt"
    19  	"sync"
    20  
    21  	"github.com/dolthub/vitess/go/mysql"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/fulltext"
    25  	"github.com/dolthub/go-mysql-server/sql/plan"
    26  	"github.com/dolthub/go-mysql-server/sql/transform"
    27  	"github.com/dolthub/go-mysql-server/sql/types"
    28  )
    29  
    30  func (b *BaseBuilder) buildInsertInto(ctx *sql.Context, ii *plan.InsertInto, row sql.Row) (sql.RowIter, error) {
    31  	dstSchema := ii.Destination.Schema()
    32  
    33  	insertable, err := plan.GetInsertable(ii.Destination)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  
    38  	var inserter sql.RowInserter
    39  
    40  	var replacer sql.RowReplacer
    41  	var updater sql.RowUpdater
    42  	// These type casts have already been asserted in the analyzer
    43  	if ii.IsReplace {
    44  		replacer = insertable.(sql.ReplaceableTable).Replacer(ctx)
    45  	} else {
    46  		inserter = insertable.Inserter(ctx)
    47  		if len(ii.OnDupExprs) > 0 {
    48  			updater = insertable.(sql.UpdatableTable).Updater(ctx)
    49  		}
    50  	}
    51  
    52  	rowIter, err := b.buildNodeExec(ctx, ii.Source, row)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	insertExpressions := getInsertExpressions(ii.Source)
    58  	insertIter := &insertIter{
    59  		schema:              dstSchema,
    60  		tableNode:           ii.Destination,
    61  		inserter:            inserter,
    62  		replacer:            replacer,
    63  		updater:             updater,
    64  		rowSource:           rowIter,
    65  		hasAutoAutoIncValue: ii.HasUnspecifiedAutoInc,
    66  		updateExprs:         ii.OnDupExprs,
    67  		insertExprs:         insertExpressions,
    68  		checks:              ii.Checks(),
    69  		ctx:                 ctx,
    70  		ignore:              ii.Ignore,
    71  	}
    72  
    73  	var ed sql.EditOpenerCloser
    74  	if replacer != nil {
    75  		ed = replacer
    76  	} else {
    77  		ed = inserter
    78  	}
    79  
    80  	if ii.Ignore {
    81  		return plan.NewCheckpointingTableEditorIter(insertIter, ed), nil
    82  	} else {
    83  		return plan.NewTableEditorIter(insertIter, ed), nil
    84  	}
    85  }
    86  
    87  func (b *BaseBuilder) buildDeleteFrom(ctx *sql.Context, n *plan.DeleteFrom, row sql.Row) (sql.RowIter, error) {
    88  	iter, err := b.buildNodeExec(ctx, n.Child, row)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	targets := n.GetDeleteTargets()
    94  	schemaPositionDeleters := make([]schemaPositionDeleter, len(targets))
    95  	schema := n.Child.Schema()
    96  
    97  	for i, target := range targets {
    98  		deletable, err := plan.GetDeletable(target)
    99  		if err != nil {
   100  			return nil, err
   101  		}
   102  		deleter := deletable.Deleter(ctx)
   103  
   104  		// By default the sourceName in the schema is the table name, but if there is a
   105  		// table alias applied, then use that instead.
   106  		sourceName := deletable.Name()
   107  		transform.Inspect(target, func(node sql.Node) bool {
   108  			if tableAlias, ok := node.(*plan.TableAlias); ok {
   109  				sourceName = tableAlias.Name()
   110  				return false
   111  			}
   112  			return true
   113  		})
   114  
   115  		start, end, err := findSourcePosition(schema, sourceName)
   116  		if err != nil {
   117  			return nil, err
   118  		}
   119  		schemaPositionDeleters[i] = schemaPositionDeleter{deleter, int(start), int(end)}
   120  	}
   121  	return newDeleteIter(iter, schema, schemaPositionDeleters...), nil
   122  }
   123  
   124  func (b *BaseBuilder) buildForeignKeyHandler(ctx *sql.Context, n *plan.ForeignKeyHandler, row sql.Row) (sql.RowIter, error) {
   125  	return b.buildNodeExec(ctx, n.OriginalNode, row)
   126  }
   127  
   128  func (b *BaseBuilder) buildUpdate(ctx *sql.Context, n *plan.Update, row sql.Row) (sql.RowIter, error) {
   129  	updatable, err := plan.GetUpdatable(n.Child)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  	updater := updatable.Updater(ctx)
   134  
   135  	iter, err := b.buildNodeExec(ctx, n.Child, row)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  
   140  	return newUpdateIter(iter, updatable.Schema(), updater, n.Checks(), n.Ignore), nil
   141  }
   142  
   143  func (b *BaseBuilder) buildDropForeignKey(ctx *sql.Context, n *plan.DropForeignKey, row sql.Row) (sql.RowIter, error) {
   144  	db, err := n.DbProvider.Database(ctx, n.Database())
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	tbl, ok, err := db.GetTableInsensitive(ctx, n.Table)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  	if !ok {
   153  		return nil, sql.ErrTableNotFound.New(n.Table)
   154  	}
   155  	fkTbl, ok := tbl.(sql.ForeignKeyTable)
   156  	if !ok {
   157  		return nil, sql.ErrNoForeignKeySupport.New(n.Name)
   158  	}
   159  	err = fkTbl.DropForeignKey(ctx, n.Name)
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  
   164  	return rowIterWithOkResultWithZeroRowsAffected(), nil
   165  }
   166  
   167  func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, row sql.Row) (sql.RowIter, error) {
   168  	var err error
   169  	var curdb sql.Database
   170  
   171  	for _, table := range n.Tables {
   172  		tbl := table.(*plan.ResolvedTable)
   173  		curdb = tbl.SqlDatabase
   174  
   175  		droppable := tbl.SqlDatabase.(sql.TableDropper)
   176  
   177  		if fkTable, err := getForeignKeyTable(tbl); err == nil {
   178  			fkChecks, err := ctx.GetSessionVariable(ctx, "foreign_key_checks")
   179  			if err != nil {
   180  				return nil, err
   181  			}
   182  			if fkChecks.(int8) == 1 {
   183  				parentFks, err := fkTable.GetReferencedForeignKeys(ctx)
   184  				if err != nil {
   185  					return nil, err
   186  				}
   187  				for i, fk := range parentFks {
   188  					// ignore self referential foreign keys
   189  					if fk.Table != fk.ParentTable {
   190  						return nil, sql.ErrForeignKeyDropTable.New(fkTable.Name(), parentFks[i].Name)
   191  					}
   192  				}
   193  			}
   194  			fks, err := fkTable.GetDeclaredForeignKeys(ctx)
   195  			if err != nil {
   196  				return nil, err
   197  			}
   198  			for _, fk := range fks {
   199  				if err = fkTable.DropForeignKey(ctx, fk.Name); err != nil {
   200  					return nil, err
   201  				}
   202  			}
   203  		}
   204  
   205  		if hasFullText(ctx, tbl) {
   206  			if err = fulltext.DropAllIndexes(ctx, tbl.Table.(sql.IndexAddressableTable), droppable.(fulltext.Database)); err != nil {
   207  				return nil, err
   208  			}
   209  		}
   210  
   211  		err = droppable.DropTable(ctx, tbl.Name())
   212  		if err != nil {
   213  			return nil, err
   214  		}
   215  	}
   216  
   217  	if len(n.TriggerNames) > 0 {
   218  		triggerDb, ok := curdb.(sql.TriggerDatabase)
   219  		if !ok {
   220  			tblNames, _ := n.TableNames()
   221  			return nil, fmt.Errorf(`tables %v are referenced in triggers %v, but database does not support triggers`, tblNames, n.TriggerNames)
   222  		}
   223  		//TODO: if dropping any triggers fail, then we'll be left in a state where triggers exist for a table that was dropped
   224  		for _, trigger := range n.TriggerNames {
   225  			err = triggerDb.DropTrigger(ctx, trigger)
   226  			if err != nil {
   227  				return nil, err
   228  			}
   229  		}
   230  	}
   231  
   232  	return rowIterWithOkResultWithZeroRowsAffected(), nil
   233  }
   234  
   235  func (b *BaseBuilder) buildTriggerRollback(ctx *sql.Context, n *plan.TriggerRollback, row sql.Row) (sql.RowIter, error) {
   236  	childIter, err := b.buildNodeExec(ctx, n.Child, row)
   237  	if err != nil {
   238  		return nil, err
   239  	}
   240  
   241  	ctx.GetLogger().Tracef("TriggerRollback creating savepoint: %s", SavePointName)
   242  
   243  	ts, ok := ctx.Session.(sql.TransactionSession)
   244  	if !ok {
   245  		return nil, fmt.Errorf("expected a sql.TransactionSession, but got %T", ctx.Session)
   246  	}
   247  
   248  	if err := ts.CreateSavepoint(ctx, ctx.GetTransaction(), SavePointName); err != nil {
   249  		ctx.GetLogger().WithError(err).Errorf("CreateSavepoint failed")
   250  	}
   251  
   252  	return &triggerRollbackIter{
   253  		child:        childIter,
   254  		hasSavepoint: true,
   255  	}, nil
   256  }
   257  
   258  func (b *BaseBuilder) buildAlterIndex(ctx *sql.Context, n *plan.AlterIndex, row sql.Row) (sql.RowIter, error) {
   259  	err := b.executeAlterIndex(ctx, n)
   260  	if err != nil {
   261  		return nil, err
   262  	}
   263  
   264  	return rowIterWithOkResultWithZeroRowsAffected(), nil
   265  }
   266  
   267  func (b *BaseBuilder) buildTriggerBeginEndBlock(ctx *sql.Context, n *plan.TriggerBeginEndBlock, row sql.Row) (sql.RowIter, error) {
   268  	return &triggerBlockIter{
   269  		statements: n.Children(),
   270  		row:        row,
   271  		once:       &sync.Once{},
   272  	}, nil
   273  }
   274  
   275  func (b *BaseBuilder) buildTriggerExecutor(ctx *sql.Context, n *plan.TriggerExecutor, row sql.Row) (sql.RowIter, error) {
   276  	childIter, err := b.buildNodeExec(ctx, n.Left(), row)
   277  	if err != nil {
   278  		return nil, err
   279  	}
   280  
   281  	return &triggerIter{
   282  		child:          childIter,
   283  		triggerTime:    n.TriggerTime,
   284  		triggerEvent:   n.TriggerEvent,
   285  		executionLogic: n.Right(),
   286  		ctx:            ctx,
   287  	}, nil
   288  }
   289  
   290  func (b *BaseBuilder) buildInsertDestination(ctx *sql.Context, n *plan.InsertDestination, row sql.Row) (sql.RowIter, error) {
   291  	return b.buildNodeExec(ctx, n.Child, row)
   292  }
   293  
   294  func (b *BaseBuilder) buildRowUpdateAccumulator(ctx *sql.Context, n *plan.RowUpdateAccumulator, row sql.Row) (sql.RowIter, error) {
   295  	rowIter, err := b.buildNodeExec(ctx, n.Child(), row)
   296  	if err != nil {
   297  		return nil, err
   298  	}
   299  
   300  	clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) == mysql.CapabilityClientFoundRows
   301  
   302  	var rowHandler accumulatorRowHandler
   303  	switch n.RowUpdateType {
   304  	case plan.UpdateTypeInsert:
   305  		insertItr, err := findInsertIter(rowIter)
   306  		if err != nil {
   307  			return nil, err
   308  		}
   309  
   310  		rowHandler = &insertRowHandler{
   311  			lastInsertIdGetter: insertItr.getAutoIncVal,
   312  		}
   313  		// TODO: some of these other row handlers also need to keep track of the last insert id
   314  	case plan.UpdateTypeReplace:
   315  		rowHandler = &replaceRowHandler{}
   316  	case plan.UpdateTypeDuplicateKeyUpdate:
   317  		rowHandler = &onDuplicateUpdateHandler{schema: n.Child().Schema(), clientFoundRowsCapability: clientFoundRowsToggled}
   318  	case plan.UpdateTypeUpdate:
   319  		schema := n.Child().Schema()
   320  		// the schema of the update node is a self-concatenation of the underlying table's, so split it in half for new /
   321  		// old row comparison purposes
   322  		rowHandler = &updateRowHandler{schema: schema[:len(schema)/2], clientFoundRowsCapability: clientFoundRowsToggled}
   323  	case plan.UpdateTypeDelete:
   324  		rowHandler = &deleteRowHandler{}
   325  	case plan.UpdateTypeJoinUpdate:
   326  		var schema sql.Schema
   327  		var updaterMap map[string]sql.RowUpdater
   328  		transform.Inspect(n.Child(), func(node sql.Node) bool {
   329  			switch node.(type) {
   330  			case *plan.JoinNode, *plan.Project:
   331  				schema = node.Schema()
   332  				return false
   333  			case *plan.UpdateJoin:
   334  				updaterMap = node.(*plan.UpdateJoin).Updaters
   335  				return true
   336  			}
   337  
   338  			return true
   339  		})
   340  
   341  		if schema == nil {
   342  			return nil, fmt.Errorf("error: No JoinNode found in query plan to go along with an UpdateTypeJoinUpdate")
   343  		}
   344  
   345  		rowHandler = &updateJoinRowHandler{joinSchema: schema, tableMap: plan.RecreateTableSchemaFromJoinSchema(schema), updaterMap: updaterMap}
   346  	default:
   347  		panic(fmt.Sprintf("Unrecognized RowUpdateType %d", n.RowUpdateType))
   348  	}
   349  
   350  	return &accumulatorIter{
   351  		iter:             rowIter,
   352  		updateRowHandler: rowHandler,
   353  	}, nil
   354  }
   355  
   356  func findInsertIter(rowIter sql.RowIter) (*insertIter, error) {
   357  	var insertItr *insertIter
   358  	switch rowIter := rowIter.(type) {
   359  	case *plan.TableEditorIter:
   360  		var ok bool
   361  		insertItr, ok = rowIter.InnerIter().(*insertIter)
   362  		if !ok {
   363  			return nil, fmt.Errorf("unexpected iter type %T", rowIter)
   364  		}
   365  	case *plan.CheckpointingTableEditorIter:
   366  		var ok bool
   367  		insertItr, ok = rowIter.InnerIter().(*insertIter)
   368  		if !ok {
   369  			return nil, fmt.Errorf("unexpected iter type %T", rowIter)
   370  		}
   371  	case *triggerIter:
   372  		var err error
   373  		insertItr, err = findInsertIter(rowIter.child)
   374  		if err != nil {
   375  			return nil, err
   376  		}
   377  	default:
   378  		return nil, fmt.Errorf("unexpected iter type %T", rowIter)
   379  	}
   380  	return insertItr, nil
   381  }
   382  
   383  func (b *BaseBuilder) buildTruncate(ctx *sql.Context, n *plan.Truncate, row sql.Row) (sql.RowIter, error) {
   384  	truncatable, err := plan.GetTruncatable(n.Child)
   385  	if err != nil {
   386  		return nil, err
   387  	}
   388  	//TODO: when performance schema summary tables are added, reset the columns to 0/NULL rather than remove rows
   389  	//TODO: close all handlers that were opened with "HANDLER OPEN"
   390  
   391  	removed, err := truncatable.Truncate(ctx)
   392  	if err != nil {
   393  		return nil, err
   394  	}
   395  	for _, col := range truncatable.Schema() {
   396  		if col.AutoIncrement {
   397  			aiTable, ok := truncatable.(sql.AutoIncrementTable)
   398  			if ok {
   399  				setter := aiTable.AutoIncrementSetter(ctx)
   400  				err = setter.SetAutoIncrementValue(ctx, uint64(1))
   401  				if err != nil {
   402  					return nil, err
   403  				}
   404  				err = setter.Close(ctx)
   405  				if err != nil {
   406  					return nil, err
   407  				}
   408  			}
   409  			break
   410  		}
   411  	}
   412  	// If we've got Full-Text indexes, then we also need to clear those tables
   413  	if hasFullText(ctx, truncatable) {
   414  		if err = rebuildFullText(ctx, truncatable.Name(), plan.GetDatabase(n.Child)); err != nil {
   415  			return nil, err
   416  		}
   417  	}
   418  	return sql.RowsToRowIter(sql.NewRow(types.NewOkResult(removed))), nil
   419  }
   420  
   421  func (b *BaseBuilder) buildUpdateSource(ctx *sql.Context, n *plan.UpdateSource, row sql.Row) (sql.RowIter, error) {
   422  	rowIter, err := b.buildNodeExec(ctx, n.Child, row)
   423  	if err != nil {
   424  		return nil, err
   425  	}
   426  
   427  	schema, err := n.GetChildSchema()
   428  	if err != nil {
   429  		return nil, err
   430  	}
   431  
   432  	return &updateSourceIter{
   433  		childIter:   rowIter,
   434  		updateExprs: n.UpdateExprs,
   435  		tableSchema: schema,
   436  		ignore:      n.Ignore,
   437  	}, nil
   438  }
   439  
   440  func (b *BaseBuilder) buildUpdateJoin(ctx *sql.Context, n *plan.UpdateJoin, row sql.Row) (sql.RowIter, error) {
   441  	ji, err := b.buildNodeExec(ctx, n.Child, row)
   442  	if err != nil {
   443  		return nil, err
   444  	}
   445  
   446  	return &updateJoinIter{
   447  		updateSourceIter: ji,
   448  		joinSchema:       n.Child.(*plan.UpdateSource).Child.Schema(),
   449  		updaters:         n.Updaters,
   450  		caches:           make(map[string]sql.KeyValueCache),
   451  		disposals:        make(map[string]sql.DisposeFunc),
   452  		joinNode:         n.Child.(*plan.UpdateSource).Child,
   453  	}, nil
   454  }