github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/dml.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rowexec 16 17 import ( 18 "fmt" 19 "sync" 20 21 "github.com/dolthub/vitess/go/mysql" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/fulltext" 25 "github.com/dolthub/go-mysql-server/sql/plan" 26 "github.com/dolthub/go-mysql-server/sql/transform" 27 "github.com/dolthub/go-mysql-server/sql/types" 28 ) 29 30 func (b *BaseBuilder) buildInsertInto(ctx *sql.Context, ii *plan.InsertInto, row sql.Row) (sql.RowIter, error) { 31 dstSchema := ii.Destination.Schema() 32 33 insertable, err := plan.GetInsertable(ii.Destination) 34 if err != nil { 35 return nil, err 36 } 37 38 var inserter sql.RowInserter 39 40 var replacer sql.RowReplacer 41 var updater sql.RowUpdater 42 // These type casts have already been asserted in the analyzer 43 if ii.IsReplace { 44 replacer = insertable.(sql.ReplaceableTable).Replacer(ctx) 45 } else { 46 inserter = insertable.Inserter(ctx) 47 if len(ii.OnDupExprs) > 0 { 48 updater = insertable.(sql.UpdatableTable).Updater(ctx) 49 } 50 } 51 52 rowIter, err := b.buildNodeExec(ctx, ii.Source, row) 53 if err != nil { 54 return nil, err 55 } 56 57 insertExpressions := getInsertExpressions(ii.Source) 58 insertIter := &insertIter{ 59 schema: dstSchema, 60 tableNode: ii.Destination, 61 inserter: inserter, 62 replacer: replacer, 63 updater: updater, 64 rowSource: rowIter, 65 hasAutoAutoIncValue: ii.HasUnspecifiedAutoInc, 66 updateExprs: ii.OnDupExprs, 67 insertExprs: insertExpressions, 68 checks: ii.Checks(), 69 ctx: ctx, 70 ignore: ii.Ignore, 71 } 72 73 var ed sql.EditOpenerCloser 74 if replacer != nil { 75 ed = replacer 76 } else { 77 ed = inserter 78 } 79 80 if ii.Ignore { 81 return plan.NewCheckpointingTableEditorIter(insertIter, ed), nil 82 } else { 83 return plan.NewTableEditorIter(insertIter, ed), nil 84 } 85 } 86 87 func (b *BaseBuilder) buildDeleteFrom(ctx *sql.Context, n *plan.DeleteFrom, row sql.Row) (sql.RowIter, error) { 88 iter, err := b.buildNodeExec(ctx, n.Child, row) 89 if err != nil { 90 return nil, err 91 } 92 93 targets := n.GetDeleteTargets() 94 schemaPositionDeleters := make([]schemaPositionDeleter, len(targets)) 95 schema := n.Child.Schema() 96 97 for i, target := range targets { 98 deletable, err := plan.GetDeletable(target) 99 if err != nil { 100 return nil, err 101 } 102 deleter := deletable.Deleter(ctx) 103 104 // By default the sourceName in the schema is the table name, but if there is a 105 // table alias applied, then use that instead. 106 sourceName := deletable.Name() 107 transform.Inspect(target, func(node sql.Node) bool { 108 if tableAlias, ok := node.(*plan.TableAlias); ok { 109 sourceName = tableAlias.Name() 110 return false 111 } 112 return true 113 }) 114 115 start, end, err := findSourcePosition(schema, sourceName) 116 if err != nil { 117 return nil, err 118 } 119 schemaPositionDeleters[i] = schemaPositionDeleter{deleter, int(start), int(end)} 120 } 121 return newDeleteIter(iter, schema, schemaPositionDeleters...), nil 122 } 123 124 func (b *BaseBuilder) buildForeignKeyHandler(ctx *sql.Context, n *plan.ForeignKeyHandler, row sql.Row) (sql.RowIter, error) { 125 return b.buildNodeExec(ctx, n.OriginalNode, row) 126 } 127 128 func (b *BaseBuilder) buildUpdate(ctx *sql.Context, n *plan.Update, row sql.Row) (sql.RowIter, error) { 129 updatable, err := plan.GetUpdatable(n.Child) 130 if err != nil { 131 return nil, err 132 } 133 updater := updatable.Updater(ctx) 134 135 iter, err := b.buildNodeExec(ctx, n.Child, row) 136 if err != nil { 137 return nil, err 138 } 139 140 return newUpdateIter(iter, updatable.Schema(), updater, n.Checks(), n.Ignore), nil 141 } 142 143 func (b *BaseBuilder) buildDropForeignKey(ctx *sql.Context, n *plan.DropForeignKey, row sql.Row) (sql.RowIter, error) { 144 db, err := n.DbProvider.Database(ctx, n.Database()) 145 if err != nil { 146 return nil, err 147 } 148 tbl, ok, err := db.GetTableInsensitive(ctx, n.Table) 149 if err != nil { 150 return nil, err 151 } 152 if !ok { 153 return nil, sql.ErrTableNotFound.New(n.Table) 154 } 155 fkTbl, ok := tbl.(sql.ForeignKeyTable) 156 if !ok { 157 return nil, sql.ErrNoForeignKeySupport.New(n.Name) 158 } 159 err = fkTbl.DropForeignKey(ctx, n.Name) 160 if err != nil { 161 return nil, err 162 } 163 164 return rowIterWithOkResultWithZeroRowsAffected(), nil 165 } 166 167 func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, row sql.Row) (sql.RowIter, error) { 168 var err error 169 var curdb sql.Database 170 171 for _, table := range n.Tables { 172 tbl := table.(*plan.ResolvedTable) 173 curdb = tbl.SqlDatabase 174 175 droppable := tbl.SqlDatabase.(sql.TableDropper) 176 177 if fkTable, err := getForeignKeyTable(tbl); err == nil { 178 fkChecks, err := ctx.GetSessionVariable(ctx, "foreign_key_checks") 179 if err != nil { 180 return nil, err 181 } 182 if fkChecks.(int8) == 1 { 183 parentFks, err := fkTable.GetReferencedForeignKeys(ctx) 184 if err != nil { 185 return nil, err 186 } 187 for i, fk := range parentFks { 188 // ignore self referential foreign keys 189 if fk.Table != fk.ParentTable { 190 return nil, sql.ErrForeignKeyDropTable.New(fkTable.Name(), parentFks[i].Name) 191 } 192 } 193 } 194 fks, err := fkTable.GetDeclaredForeignKeys(ctx) 195 if err != nil { 196 return nil, err 197 } 198 for _, fk := range fks { 199 if err = fkTable.DropForeignKey(ctx, fk.Name); err != nil { 200 return nil, err 201 } 202 } 203 } 204 205 if hasFullText(ctx, tbl) { 206 if err = fulltext.DropAllIndexes(ctx, tbl.Table.(sql.IndexAddressableTable), droppable.(fulltext.Database)); err != nil { 207 return nil, err 208 } 209 } 210 211 err = droppable.DropTable(ctx, tbl.Name()) 212 if err != nil { 213 return nil, err 214 } 215 } 216 217 if len(n.TriggerNames) > 0 { 218 triggerDb, ok := curdb.(sql.TriggerDatabase) 219 if !ok { 220 tblNames, _ := n.TableNames() 221 return nil, fmt.Errorf(`tables %v are referenced in triggers %v, but database does not support triggers`, tblNames, n.TriggerNames) 222 } 223 //TODO: if dropping any triggers fail, then we'll be left in a state where triggers exist for a table that was dropped 224 for _, trigger := range n.TriggerNames { 225 err = triggerDb.DropTrigger(ctx, trigger) 226 if err != nil { 227 return nil, err 228 } 229 } 230 } 231 232 return rowIterWithOkResultWithZeroRowsAffected(), nil 233 } 234 235 func (b *BaseBuilder) buildTriggerRollback(ctx *sql.Context, n *plan.TriggerRollback, row sql.Row) (sql.RowIter, error) { 236 childIter, err := b.buildNodeExec(ctx, n.Child, row) 237 if err != nil { 238 return nil, err 239 } 240 241 ctx.GetLogger().Tracef("TriggerRollback creating savepoint: %s", SavePointName) 242 243 ts, ok := ctx.Session.(sql.TransactionSession) 244 if !ok { 245 return nil, fmt.Errorf("expected a sql.TransactionSession, but got %T", ctx.Session) 246 } 247 248 if err := ts.CreateSavepoint(ctx, ctx.GetTransaction(), SavePointName); err != nil { 249 ctx.GetLogger().WithError(err).Errorf("CreateSavepoint failed") 250 } 251 252 return &triggerRollbackIter{ 253 child: childIter, 254 hasSavepoint: true, 255 }, nil 256 } 257 258 func (b *BaseBuilder) buildAlterIndex(ctx *sql.Context, n *plan.AlterIndex, row sql.Row) (sql.RowIter, error) { 259 err := b.executeAlterIndex(ctx, n) 260 if err != nil { 261 return nil, err 262 } 263 264 return rowIterWithOkResultWithZeroRowsAffected(), nil 265 } 266 267 func (b *BaseBuilder) buildTriggerBeginEndBlock(ctx *sql.Context, n *plan.TriggerBeginEndBlock, row sql.Row) (sql.RowIter, error) { 268 return &triggerBlockIter{ 269 statements: n.Children(), 270 row: row, 271 once: &sync.Once{}, 272 }, nil 273 } 274 275 func (b *BaseBuilder) buildTriggerExecutor(ctx *sql.Context, n *plan.TriggerExecutor, row sql.Row) (sql.RowIter, error) { 276 childIter, err := b.buildNodeExec(ctx, n.Left(), row) 277 if err != nil { 278 return nil, err 279 } 280 281 return &triggerIter{ 282 child: childIter, 283 triggerTime: n.TriggerTime, 284 triggerEvent: n.TriggerEvent, 285 executionLogic: n.Right(), 286 ctx: ctx, 287 }, nil 288 } 289 290 func (b *BaseBuilder) buildInsertDestination(ctx *sql.Context, n *plan.InsertDestination, row sql.Row) (sql.RowIter, error) { 291 return b.buildNodeExec(ctx, n.Child, row) 292 } 293 294 func (b *BaseBuilder) buildRowUpdateAccumulator(ctx *sql.Context, n *plan.RowUpdateAccumulator, row sql.Row) (sql.RowIter, error) { 295 rowIter, err := b.buildNodeExec(ctx, n.Child(), row) 296 if err != nil { 297 return nil, err 298 } 299 300 clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) == mysql.CapabilityClientFoundRows 301 302 var rowHandler accumulatorRowHandler 303 switch n.RowUpdateType { 304 case plan.UpdateTypeInsert: 305 insertItr, err := findInsertIter(rowIter) 306 if err != nil { 307 return nil, err 308 } 309 310 rowHandler = &insertRowHandler{ 311 lastInsertIdGetter: insertItr.getAutoIncVal, 312 } 313 // TODO: some of these other row handlers also need to keep track of the last insert id 314 case plan.UpdateTypeReplace: 315 rowHandler = &replaceRowHandler{} 316 case plan.UpdateTypeDuplicateKeyUpdate: 317 rowHandler = &onDuplicateUpdateHandler{schema: n.Child().Schema(), clientFoundRowsCapability: clientFoundRowsToggled} 318 case plan.UpdateTypeUpdate: 319 schema := n.Child().Schema() 320 // the schema of the update node is a self-concatenation of the underlying table's, so split it in half for new / 321 // old row comparison purposes 322 rowHandler = &updateRowHandler{schema: schema[:len(schema)/2], clientFoundRowsCapability: clientFoundRowsToggled} 323 case plan.UpdateTypeDelete: 324 rowHandler = &deleteRowHandler{} 325 case plan.UpdateTypeJoinUpdate: 326 var schema sql.Schema 327 var updaterMap map[string]sql.RowUpdater 328 transform.Inspect(n.Child(), func(node sql.Node) bool { 329 switch node.(type) { 330 case *plan.JoinNode, *plan.Project: 331 schema = node.Schema() 332 return false 333 case *plan.UpdateJoin: 334 updaterMap = node.(*plan.UpdateJoin).Updaters 335 return true 336 } 337 338 return true 339 }) 340 341 if schema == nil { 342 return nil, fmt.Errorf("error: No JoinNode found in query plan to go along with an UpdateTypeJoinUpdate") 343 } 344 345 rowHandler = &updateJoinRowHandler{joinSchema: schema, tableMap: plan.RecreateTableSchemaFromJoinSchema(schema), updaterMap: updaterMap} 346 default: 347 panic(fmt.Sprintf("Unrecognized RowUpdateType %d", n.RowUpdateType)) 348 } 349 350 return &accumulatorIter{ 351 iter: rowIter, 352 updateRowHandler: rowHandler, 353 }, nil 354 } 355 356 func findInsertIter(rowIter sql.RowIter) (*insertIter, error) { 357 var insertItr *insertIter 358 switch rowIter := rowIter.(type) { 359 case *plan.TableEditorIter: 360 var ok bool 361 insertItr, ok = rowIter.InnerIter().(*insertIter) 362 if !ok { 363 return nil, fmt.Errorf("unexpected iter type %T", rowIter) 364 } 365 case *plan.CheckpointingTableEditorIter: 366 var ok bool 367 insertItr, ok = rowIter.InnerIter().(*insertIter) 368 if !ok { 369 return nil, fmt.Errorf("unexpected iter type %T", rowIter) 370 } 371 case *triggerIter: 372 var err error 373 insertItr, err = findInsertIter(rowIter.child) 374 if err != nil { 375 return nil, err 376 } 377 default: 378 return nil, fmt.Errorf("unexpected iter type %T", rowIter) 379 } 380 return insertItr, nil 381 } 382 383 func (b *BaseBuilder) buildTruncate(ctx *sql.Context, n *plan.Truncate, row sql.Row) (sql.RowIter, error) { 384 truncatable, err := plan.GetTruncatable(n.Child) 385 if err != nil { 386 return nil, err 387 } 388 //TODO: when performance schema summary tables are added, reset the columns to 0/NULL rather than remove rows 389 //TODO: close all handlers that were opened with "HANDLER OPEN" 390 391 removed, err := truncatable.Truncate(ctx) 392 if err != nil { 393 return nil, err 394 } 395 for _, col := range truncatable.Schema() { 396 if col.AutoIncrement { 397 aiTable, ok := truncatable.(sql.AutoIncrementTable) 398 if ok { 399 setter := aiTable.AutoIncrementSetter(ctx) 400 err = setter.SetAutoIncrementValue(ctx, uint64(1)) 401 if err != nil { 402 return nil, err 403 } 404 err = setter.Close(ctx) 405 if err != nil { 406 return nil, err 407 } 408 } 409 break 410 } 411 } 412 // If we've got Full-Text indexes, then we also need to clear those tables 413 if hasFullText(ctx, truncatable) { 414 if err = rebuildFullText(ctx, truncatable.Name(), plan.GetDatabase(n.Child)); err != nil { 415 return nil, err 416 } 417 } 418 return sql.RowsToRowIter(sql.NewRow(types.NewOkResult(removed))), nil 419 } 420 421 func (b *BaseBuilder) buildUpdateSource(ctx *sql.Context, n *plan.UpdateSource, row sql.Row) (sql.RowIter, error) { 422 rowIter, err := b.buildNodeExec(ctx, n.Child, row) 423 if err != nil { 424 return nil, err 425 } 426 427 schema, err := n.GetChildSchema() 428 if err != nil { 429 return nil, err 430 } 431 432 return &updateSourceIter{ 433 childIter: rowIter, 434 updateExprs: n.UpdateExprs, 435 tableSchema: schema, 436 ignore: n.Ignore, 437 }, nil 438 } 439 440 func (b *BaseBuilder) buildUpdateJoin(ctx *sql.Context, n *plan.UpdateJoin, row sql.Row) (sql.RowIter, error) { 441 ji, err := b.buildNodeExec(ctx, n.Child, row) 442 if err != nil { 443 return nil, err 444 } 445 446 return &updateJoinIter{ 447 updateSourceIter: ji, 448 joinSchema: n.Child.(*plan.UpdateSource).Child.Schema(), 449 updaters: n.Updaters, 450 caches: make(map[string]sql.KeyValueCache), 451 disposals: make(map[string]sql.DisposeFunc), 452 joinNode: n.Child.(*plan.UpdateSource).Child, 453 }, nil 454 }