github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/dml_iters.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rowexec 16 17 import ( 18 "fmt" 19 "io" 20 "sync" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/expression" 24 "github.com/dolthub/go-mysql-server/sql/plan" 25 "github.com/dolthub/go-mysql-server/sql/transform" 26 "github.com/dolthub/go-mysql-server/sql/types" 27 ) 28 29 const SavePointName = "__go_mysql_server_starting_savepoint__" 30 31 type triggerRollbackIter struct { 32 child sql.RowIter 33 hasSavepoint bool 34 } 35 36 func (t *triggerRollbackIter) Next(ctx *sql.Context) (row sql.Row, returnErr error) { 37 childRow, err := t.child.Next(ctx) 38 39 ts, ok := ctx.Session.(sql.TransactionSession) 40 if !ok { 41 return nil, fmt.Errorf("expected a sql.TransactionSession, but got %T", ctx.Session) 42 } 43 44 // Rollback if error occurred 45 if err != nil && err != io.EOF { 46 if err := ts.RollbackToSavepoint(ctx, ctx.GetTransaction(), SavePointName); err != nil { 47 ctx.GetLogger().WithError(err).Errorf("Unexpected error when calling RollbackToSavePoint during triggerRollbackIter.Next()") 48 } 49 if err := ts.ReleaseSavepoint(ctx, ctx.GetTransaction(), SavePointName); err != nil { 50 ctx.GetLogger().WithError(err).Errorf("Unexpected error when calling ReleaseSavepoint during triggerRollbackIter.Next()") 51 } else { 52 t.hasSavepoint = false 53 } 54 } 55 56 return childRow, err 57 } 58 59 func (t *triggerRollbackIter) Close(ctx *sql.Context) error { 60 ts, ok := ctx.Session.(sql.TransactionSession) 61 if !ok { 62 return fmt.Errorf("expected a sql.TransactionSession, but got %T", ctx.Session) 63 } 64 65 if t.hasSavepoint { 66 if err := ts.ReleaseSavepoint(ctx, ctx.GetTransaction(), SavePointName); err != nil { 67 ctx.GetLogger().WithError(err).Errorf("Unexpected error when calling ReleaseSavepoint during triggerRollbackIter.Close()") 68 } 69 t.hasSavepoint = false 70 } 71 return t.child.Close(ctx) 72 } 73 74 // triggerBlockIter is the sql.RowIter for TRIGGER BEGIN/END blocks, which operate differently than normal blocks. 75 type triggerBlockIter struct { 76 statements []sql.Node 77 row sql.Row 78 once *sync.Once 79 b *BaseBuilder 80 } 81 82 var _ sql.RowIter = (*triggerBlockIter)(nil) 83 84 // Next implements the sql.RowIter interface. 85 func (i *triggerBlockIter) Next(ctx *sql.Context) (sql.Row, error) { 86 run := false 87 i.once.Do(func() { 88 run = true 89 }) 90 91 if !run { 92 return nil, io.EOF 93 } 94 95 row := i.row 96 for _, s := range i.statements { 97 subIter, err := i.b.buildNodeExec(ctx, s, row) 98 if err != nil { 99 return nil, err 100 } 101 102 for { 103 newRow, err := subIter.Next(ctx) 104 if err == io.EOF { 105 err := subIter.Close(ctx) 106 if err != nil { 107 return nil, err 108 } 109 break 110 } else if err != nil { 111 _ = subIter.Close(ctx) 112 return nil, err 113 } 114 115 // We only return the result of a trigger block statement in certain cases, specifically when we are setting the 116 // value of new.field, so that the wrapping iterator can use it for the insert / update. Otherwise, this iterator 117 // always returns its input row. 118 if shouldUseTriggerStatementForReturnRow(s) { 119 row = newRow[len(newRow)/2:] 120 } 121 } 122 } 123 124 return row, nil 125 } 126 127 // shouldUseTriggerStatementForReturnRow returns whether the statement has Set node that contains GetField expression, 128 // which means whether there is column value update. The Set node can be inside other nodes, so need to inspect all nodes 129 // of the given node. 130 func shouldUseTriggerStatementForReturnRow(stmt sql.Node) bool { 131 hasSetField := false 132 transform.Inspect(stmt, func(n sql.Node) bool { 133 switch logic := n.(type) { 134 case *plan.Set: 135 for _, expr := range logic.Exprs { 136 sql.Inspect(expr.(*expression.SetField).LeftChild, func(e sql.Expression) bool { 137 if _, ok := e.(*expression.GetField); ok { 138 hasSetField = true 139 return false 140 } 141 return true 142 }) 143 } 144 } 145 return true 146 }) 147 return hasSetField 148 } 149 150 // Close implements the sql.RowIter interface. 151 func (i *triggerBlockIter) Close(*sql.Context) error { 152 return nil 153 } 154 155 type triggerIter struct { 156 child sql.RowIter 157 executionLogic sql.Node 158 triggerTime plan.TriggerTime 159 triggerEvent plan.TriggerEvent 160 ctx *sql.Context 161 b *BaseBuilder 162 } 163 164 // prependRowInPlanForTriggerExecution returns a transformation function that prepends the row given to any row source in a query 165 // plan. Any source of rows, as well as any node that alters the schema of its children, will be wrapped so that its 166 // result rows are prepended with the row given. 167 func prependRowInPlanForTriggerExecution(row sql.Row) func(c transform.Context) (sql.Node, transform.TreeIdentity, error) { 168 return func(c transform.Context) (sql.Node, transform.TreeIdentity, error) { 169 switch n := c.Node.(type) { 170 case *plan.Project: 171 // Only prepend rows for projects that aren't the input to inserts and other triggers 172 switch c.Parent.(type) { 173 case *plan.InsertInto, *plan.TriggerExecutor: 174 return n, transform.SameTree, nil 175 default: 176 return plan.NewPrependNode(n, row), transform.NewTree, nil 177 } 178 case *plan.ResolvedTable, *plan.IndexedTableAccess: 179 return plan.NewPrependNode(n, row), transform.NewTree, nil 180 default: 181 return n, transform.SameTree, nil 182 } 183 } 184 } 185 186 func (t *triggerIter) Next(ctx *sql.Context) (row sql.Row, returnErr error) { 187 childRow, err := t.child.Next(ctx) 188 if err != nil { 189 return nil, err 190 } 191 192 // Wrap the execution logic with the current child row before executing it. 193 logic, _, err := transform.NodeWithCtx(t.executionLogic, nil, prependRowInPlanForTriggerExecution(childRow)) 194 if err != nil { 195 return nil, err 196 } 197 198 // We don't do anything interesting with this subcontext yet, but it's a good idea to cancel it independently of the 199 // parent context if something goes wrong in trigger execution. 200 ctx, cancelFunc := t.ctx.NewSubContext() 201 defer cancelFunc() 202 203 logicIter, err := t.b.buildNodeExec(ctx, logic, childRow) 204 if err != nil { 205 return nil, err 206 } 207 208 defer func() { 209 err := logicIter.Close(t.ctx) 210 if returnErr == nil { 211 returnErr = err 212 } 213 }() 214 215 var logicRow sql.Row 216 for { 217 row, err := logicIter.Next(ctx) 218 if err == io.EOF { 219 break 220 } 221 if err != nil { 222 return nil, err 223 } 224 logicRow = row 225 } 226 227 // For some logic statements, we want to return the result of the logic operation as our row, e.g. a Set that alters 228 // the fields of the new row 229 if ok, returnRow := shouldUseLogicResult(logic, logicRow); ok { 230 return returnRow, nil 231 } 232 233 return childRow, nil 234 } 235 236 func shouldUseLogicResult(logic sql.Node, row sql.Row) (bool, sql.Row) { 237 switch logic := logic.(type) { 238 // TODO: are there other statement types that we should use here? 239 case *plan.Set: 240 hasSetField := false 241 for _, expr := range logic.Exprs { 242 sql.Inspect(expr.(*expression.SetField).LeftChild, func(e sql.Expression) bool { 243 if _, ok := e.(*expression.GetField); ok { 244 hasSetField = true 245 return false 246 } 247 return true 248 }) 249 } 250 return hasSetField, row[len(row)/2:] 251 case *plan.TriggerBeginEndBlock: 252 hasSetField := false 253 transform.Inspect(logic, func(n sql.Node) bool { 254 set, ok := n.(*plan.Set) 255 if !ok { 256 return true 257 } 258 for _, expr := range set.Exprs { 259 sql.Inspect(expr.(*expression.SetField).LeftChild, func(e sql.Expression) bool { 260 if _, ok := e.(*expression.GetField); ok { 261 hasSetField = true 262 return false 263 } 264 return true 265 }) 266 } 267 return !hasSetField 268 }) 269 return hasSetField, row 270 default: 271 return false, nil 272 } 273 } 274 275 func (t *triggerIter) Close(ctx *sql.Context) error { 276 return t.child.Close(ctx) 277 } 278 279 type accumulatorRowHandler interface { 280 handleRowUpdate(row sql.Row) error 281 okResult() types.OkResult 282 } 283 284 // TODO: Extend this to UPDATE IGNORE JOIN 285 type updateIgnoreAccumulatorRowHandler interface { 286 accumulatorRowHandler 287 handleRowUpdateWithIgnore(row sql.Row, ignore bool) error 288 } 289 290 type insertRowHandler struct { 291 rowsAffected int 292 lastInsertId uint64 293 updatedAutoIncrementValue bool 294 lastInsertIdGetter func(row sql.Row) int64 295 } 296 297 func (i *insertRowHandler) handleRowUpdate(row sql.Row) error { 298 if !i.updatedAutoIncrementValue { 299 i.updatedAutoIncrementValue = true 300 i.lastInsertId = uint64(i.lastInsertIdGetter(row)) 301 } 302 i.rowsAffected++ 303 return nil 304 } 305 306 func (i *insertRowHandler) okResult() types.OkResult { 307 return types.OkResult{ 308 RowsAffected: uint64(i.rowsAffected), 309 InsertID: i.lastInsertId, 310 } 311 } 312 313 type replaceRowHandler struct { 314 rowsAffected int 315 } 316 317 func (r *replaceRowHandler) handleRowUpdate(row sql.Row) error { 318 r.rowsAffected++ 319 320 // If a row was deleted as well as inserted, increment the counter again. A row was deleted if at least one column in 321 // the first half of the row is non-null. 322 for i := 0; i < len(row)/2; i++ { 323 if row[i] != nil { 324 r.rowsAffected++ 325 break 326 } 327 } 328 329 return nil 330 } 331 332 func (r *replaceRowHandler) okResult() types.OkResult { 333 return types.NewOkResult(r.rowsAffected) 334 } 335 336 type onDuplicateUpdateHandler struct { 337 rowsAffected int 338 schema sql.Schema 339 clientFoundRowsCapability bool 340 } 341 342 func (o *onDuplicateUpdateHandler) handleRowUpdate(row sql.Row) error { 343 // See https://dev.mysql.com/doc/refman/8.0/en/insert-on-duplicate.html for row count semantics 344 // If a row was inserted, increment by 1 345 if len(row) == len(o.schema) { 346 o.rowsAffected++ 347 return nil 348 } 349 350 // Otherwise (a row was updated), increment by 2 if the row changed, 0 if not 351 oldRow := row[:len(row)/2] 352 newRow := row[len(row)/2:] 353 if equals, err := oldRow.Equals(newRow, o.schema); err == nil { 354 if equals { 355 // Ig the CLIENT_FOUND_ROWS capabilities flag is set, increment by 1 if a row stays the same. 356 if o.clientFoundRowsCapability { 357 o.rowsAffected++ 358 } 359 } else { 360 o.rowsAffected += 2 361 } 362 } else { 363 o.rowsAffected++ 364 } 365 366 return nil 367 } 368 369 func (o *onDuplicateUpdateHandler) okResult() types.OkResult { 370 return types.NewOkResult(o.rowsAffected) 371 } 372 373 type updateRowHandler struct { 374 rowsMatched int 375 rowsAffected int 376 schema sql.Schema 377 clientFoundRowsCapability bool 378 } 379 380 func (u *updateRowHandler) handleRowUpdate(row sql.Row) error { 381 u.rowsMatched++ 382 oldRow := row[:len(row)/2] 383 newRow := row[len(row)/2:] 384 if equals, err := oldRow.Equals(newRow, u.schema); err == nil { 385 if !equals { 386 u.rowsAffected++ 387 } 388 } else { 389 return err 390 } 391 return nil 392 } 393 394 func (u *updateRowHandler) handleRowUpdateWithIgnore(row sql.Row, ignore bool) error { 395 if !ignore { 396 return u.handleRowUpdate(row) 397 } 398 399 u.rowsMatched++ 400 return nil 401 } 402 403 func (u *updateRowHandler) okResult() types.OkResult { 404 affected := u.rowsAffected 405 if u.clientFoundRowsCapability { 406 affected = u.rowsMatched 407 } 408 return types.OkResult{ 409 RowsAffected: uint64(affected), 410 Info: plan.UpdateInfo{ 411 Matched: u.rowsMatched, 412 Updated: u.rowsAffected, 413 Warnings: 0, 414 }, 415 } 416 } 417 418 func (u *updateRowHandler) RowsMatched() int64 { 419 return int64(u.rowsMatched) 420 } 421 422 // updateJoinRowHandler handles row update count for all UPDATEs that use a JOIN. 423 type updateJoinRowHandler struct { 424 rowsMatched int 425 rowsAffected int 426 joinSchema sql.Schema 427 tableMap map[string]sql.Schema // Needs to only be the tables that can be updated. 428 updaterMap map[string]sql.RowUpdater 429 } 430 431 func (u *updateJoinRowHandler) handleRowUpdate(row sql.Row) error { 432 oldJoinRow := row[:len(row)/2] 433 newJoinRow := row[len(row)/2:] 434 435 tableToOldRow := plan.SplitRowIntoTableRowMap(oldJoinRow, u.joinSchema) 436 tableToNewRow := plan.SplitRowIntoTableRowMap(newJoinRow, u.joinSchema) 437 438 for tableName, _ := range u.updaterMap { 439 u.rowsMatched++ // TODO: This currently returns the incorrect answer 440 tableOldRow := tableToOldRow[tableName] 441 tableNewRow := tableToNewRow[tableName] 442 if equals, err := tableOldRow.Equals(tableNewRow, u.tableMap[tableName]); err == nil { 443 if !equals { 444 u.rowsAffected++ 445 } 446 } else { 447 return err 448 } 449 } 450 return nil 451 } 452 453 func (u *updateJoinRowHandler) okResult() types.OkResult { 454 return types.OkResult{ 455 RowsAffected: uint64(u.rowsAffected), 456 Info: plan.UpdateInfo{ 457 Matched: u.rowsMatched, 458 Updated: u.rowsAffected, 459 Warnings: 0, 460 }, 461 } 462 } 463 464 func (u *updateJoinRowHandler) RowsMatched() int64 { 465 return int64(u.rowsMatched) 466 } 467 468 type deleteRowHandler struct { 469 rowsAffected int 470 } 471 472 func (u *deleteRowHandler) handleRowUpdate(row sql.Row) error { 473 u.rowsAffected++ 474 return nil 475 } 476 477 func (u *deleteRowHandler) okResult() types.OkResult { 478 return types.NewOkResult(u.rowsAffected) 479 } 480 481 type accumulatorIter struct { 482 iter sql.RowIter 483 once sync.Once 484 updateRowHandler accumulatorRowHandler 485 } 486 487 func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) { 488 run := false 489 a.once.Do(func() { 490 run = true 491 }) 492 493 if !run { 494 return nil, io.EOF 495 } 496 497 oldLastInsertId := ctx.Session.GetLastQueryInfo(sql.LastInsertId) 498 if oldLastInsertId != 0 { 499 ctx.Session.SetLastQueryInfo(sql.LastInsertId, -1) 500 } 501 502 // We close our child iterator before returning any results. In 503 // particular, the LOAD DATA source iterator needs to be closed before 504 // results are returned. 505 defer func() { 506 cerr := a.iter.Close(ctx) 507 if err == nil { 508 err = cerr 509 } 510 }() 511 512 for { 513 row, err := a.iter.Next(ctx) 514 igErr, isIg := err.(sql.IgnorableError) 515 select { 516 case <-ctx.Done(): 517 return nil, ctx.Err() 518 default: 519 } 520 if err == io.EOF { 521 // TODO: The information flow here is pretty gnarly. We 522 // set some session variables based on the result, and 523 // we actually use a session variable to set 524 // InsertID. This should be improved. 525 526 // UPDATE statements also set FoundRows to the number of rows that 527 // matched the WHERE clause, same as a SELECT. 528 if ma, ok := a.updateRowHandler.(matchingAccumulator); ok { 529 ctx.SetLastQueryInfo(sql.FoundRows, ma.RowsMatched()) 530 } 531 532 newLastInsertId := ctx.Session.GetLastQueryInfo(sql.LastInsertId) 533 if newLastInsertId == -1 { 534 ctx.Session.SetLastQueryInfo(sql.LastInsertId, oldLastInsertId) 535 } 536 537 res := a.updateRowHandler.okResult() // TODO: Should add warnings here 538 539 // For some update accumulators, we don't accurately track the last insert ID in the handler and need to set 540 // it manually in the result by getting it from the session. This doesn't work correctly in all cases and needs 541 // to be fixed. See comment in buildRowUpdateAccumulator in rowexec/dml.go 542 switch a.updateRowHandler.(type) { 543 case *onDuplicateUpdateHandler, *replaceRowHandler: 544 res.InsertID = uint64(newLastInsertId) 545 } 546 547 // By definition, ROW_COUNT() is equal to RowsAffected. 548 ctx.SetLastQueryInfo(sql.RowCount, int64(res.RowsAffected)) 549 550 return sql.NewRow(res), nil 551 } else if isIg { 552 if ui, ok := a.updateRowHandler.(updateIgnoreAccumulatorRowHandler); ok { 553 err = ui.handleRowUpdateWithIgnore(igErr.OffendingRow, true) 554 if err != nil { 555 return nil, err 556 } 557 } 558 } else if err != nil { 559 return nil, err 560 } else { 561 err = a.updateRowHandler.handleRowUpdate(row) 562 if err != nil { 563 return nil, err 564 } 565 } 566 } 567 } 568 569 func (a *accumulatorIter) Close(ctx *sql.Context) error { 570 return nil 571 } 572 573 type matchingAccumulator interface { 574 RowsMatched() int64 575 } 576 577 type updateSourceIter struct { 578 childIter sql.RowIter 579 updateExprs []sql.Expression 580 tableSchema sql.Schema 581 ignore bool 582 } 583 584 func (u *updateSourceIter) Next(ctx *sql.Context) (sql.Row, error) { 585 oldRow, err := u.childIter.Next(ctx) 586 if err != nil { 587 return nil, err 588 } 589 590 newRow, err := applyUpdateExpressionsWithIgnore(ctx, u.updateExprs, u.tableSchema, oldRow, u.ignore) 591 if err != nil { 592 return nil, err 593 } 594 595 // Reduce the row to the length of the schema. The length can differ when some update values come from an outer 596 // scope, which will be the first N values in the row. 597 // TODO: handle this in the analyzer instead? 598 expectedSchemaLen := len(u.tableSchema) 599 if expectedSchemaLen < len(oldRow) { 600 oldRow = oldRow[len(oldRow)-expectedSchemaLen:] 601 newRow = newRow[len(newRow)-expectedSchemaLen:] 602 } 603 604 return oldRow.Append(newRow), nil 605 } 606 607 func (u *updateSourceIter) Close(ctx *sql.Context) error { 608 return u.childIter.Close(ctx) 609 }