github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/dml_iters.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"sync"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/expression"
    24  	"github.com/dolthub/go-mysql-server/sql/plan"
    25  	"github.com/dolthub/go-mysql-server/sql/transform"
    26  	"github.com/dolthub/go-mysql-server/sql/types"
    27  )
    28  
    29  const SavePointName = "__go_mysql_server_starting_savepoint__"
    30  
    31  type triggerRollbackIter struct {
    32  	child        sql.RowIter
    33  	hasSavepoint bool
    34  }
    35  
    36  func (t *triggerRollbackIter) Next(ctx *sql.Context) (row sql.Row, returnErr error) {
    37  	childRow, err := t.child.Next(ctx)
    38  
    39  	ts, ok := ctx.Session.(sql.TransactionSession)
    40  	if !ok {
    41  		return nil, fmt.Errorf("expected a sql.TransactionSession, but got %T", ctx.Session)
    42  	}
    43  
    44  	// Rollback if error occurred
    45  	if err != nil && err != io.EOF {
    46  		if err := ts.RollbackToSavepoint(ctx, ctx.GetTransaction(), SavePointName); err != nil {
    47  			ctx.GetLogger().WithError(err).Errorf("Unexpected error when calling RollbackToSavePoint during triggerRollbackIter.Next()")
    48  		}
    49  		if err := ts.ReleaseSavepoint(ctx, ctx.GetTransaction(), SavePointName); err != nil {
    50  			ctx.GetLogger().WithError(err).Errorf("Unexpected error when calling ReleaseSavepoint during triggerRollbackIter.Next()")
    51  		} else {
    52  			t.hasSavepoint = false
    53  		}
    54  	}
    55  
    56  	return childRow, err
    57  }
    58  
    59  func (t *triggerRollbackIter) Close(ctx *sql.Context) error {
    60  	ts, ok := ctx.Session.(sql.TransactionSession)
    61  	if !ok {
    62  		return fmt.Errorf("expected a sql.TransactionSession, but got %T", ctx.Session)
    63  	}
    64  
    65  	if t.hasSavepoint {
    66  		if err := ts.ReleaseSavepoint(ctx, ctx.GetTransaction(), SavePointName); err != nil {
    67  			ctx.GetLogger().WithError(err).Errorf("Unexpected error when calling ReleaseSavepoint during triggerRollbackIter.Close()")
    68  		}
    69  		t.hasSavepoint = false
    70  	}
    71  	return t.child.Close(ctx)
    72  }
    73  
    74  // triggerBlockIter is the sql.RowIter for TRIGGER BEGIN/END blocks, which operate differently than normal blocks.
    75  type triggerBlockIter struct {
    76  	statements []sql.Node
    77  	row        sql.Row
    78  	once       *sync.Once
    79  	b          *BaseBuilder
    80  }
    81  
    82  var _ sql.RowIter = (*triggerBlockIter)(nil)
    83  
    84  // Next implements the sql.RowIter interface.
    85  func (i *triggerBlockIter) Next(ctx *sql.Context) (sql.Row, error) {
    86  	run := false
    87  	i.once.Do(func() {
    88  		run = true
    89  	})
    90  
    91  	if !run {
    92  		return nil, io.EOF
    93  	}
    94  
    95  	row := i.row
    96  	for _, s := range i.statements {
    97  		subIter, err := i.b.buildNodeExec(ctx, s, row)
    98  		if err != nil {
    99  			return nil, err
   100  		}
   101  
   102  		for {
   103  			newRow, err := subIter.Next(ctx)
   104  			if err == io.EOF {
   105  				err := subIter.Close(ctx)
   106  				if err != nil {
   107  					return nil, err
   108  				}
   109  				break
   110  			} else if err != nil {
   111  				_ = subIter.Close(ctx)
   112  				return nil, err
   113  			}
   114  
   115  			// We only return the result of a trigger block statement in certain cases, specifically when we are setting the
   116  			// value of new.field, so that the wrapping iterator can use it for the insert / update. Otherwise, this iterator
   117  			// always returns its input row.
   118  			if shouldUseTriggerStatementForReturnRow(s) {
   119  				row = newRow[len(newRow)/2:]
   120  			}
   121  		}
   122  	}
   123  
   124  	return row, nil
   125  }
   126  
   127  // shouldUseTriggerStatementForReturnRow returns whether the statement has Set node that contains GetField expression,
   128  // which means whether there is column value update. The Set node can be inside other nodes, so need to inspect all nodes
   129  // of the given node.
   130  func shouldUseTriggerStatementForReturnRow(stmt sql.Node) bool {
   131  	hasSetField := false
   132  	transform.Inspect(stmt, func(n sql.Node) bool {
   133  		switch logic := n.(type) {
   134  		case *plan.Set:
   135  			for _, expr := range logic.Exprs {
   136  				sql.Inspect(expr.(*expression.SetField).LeftChild, func(e sql.Expression) bool {
   137  					if _, ok := e.(*expression.GetField); ok {
   138  						hasSetField = true
   139  						return false
   140  					}
   141  					return true
   142  				})
   143  			}
   144  		}
   145  		return true
   146  	})
   147  	return hasSetField
   148  }
   149  
   150  // Close implements the sql.RowIter interface.
   151  func (i *triggerBlockIter) Close(*sql.Context) error {
   152  	return nil
   153  }
   154  
   155  type triggerIter struct {
   156  	child          sql.RowIter
   157  	executionLogic sql.Node
   158  	triggerTime    plan.TriggerTime
   159  	triggerEvent   plan.TriggerEvent
   160  	ctx            *sql.Context
   161  	b              *BaseBuilder
   162  }
   163  
   164  // prependRowInPlanForTriggerExecution returns a transformation function that prepends the row given to any row source in a query
   165  // plan. Any source of rows, as well as any node that alters the schema of its children, will be wrapped so that its
   166  // result rows are prepended with the row given.
   167  func prependRowInPlanForTriggerExecution(row sql.Row) func(c transform.Context) (sql.Node, transform.TreeIdentity, error) {
   168  	return func(c transform.Context) (sql.Node, transform.TreeIdentity, error) {
   169  		switch n := c.Node.(type) {
   170  		case *plan.Project:
   171  			// Only prepend rows for projects that aren't the input to inserts and other triggers
   172  			switch c.Parent.(type) {
   173  			case *plan.InsertInto, *plan.TriggerExecutor:
   174  				return n, transform.SameTree, nil
   175  			default:
   176  				return plan.NewPrependNode(n, row), transform.NewTree, nil
   177  			}
   178  		case *plan.ResolvedTable, *plan.IndexedTableAccess:
   179  			return plan.NewPrependNode(n, row), transform.NewTree, nil
   180  		default:
   181  			return n, transform.SameTree, nil
   182  		}
   183  	}
   184  }
   185  
   186  func (t *triggerIter) Next(ctx *sql.Context) (row sql.Row, returnErr error) {
   187  	childRow, err := t.child.Next(ctx)
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  
   192  	// Wrap the execution logic with the current child row before executing it.
   193  	logic, _, err := transform.NodeWithCtx(t.executionLogic, nil, prependRowInPlanForTriggerExecution(childRow))
   194  	if err != nil {
   195  		return nil, err
   196  	}
   197  
   198  	// We don't do anything interesting with this subcontext yet, but it's a good idea to cancel it independently of the
   199  	// parent context if something goes wrong in trigger execution.
   200  	ctx, cancelFunc := t.ctx.NewSubContext()
   201  	defer cancelFunc()
   202  
   203  	logicIter, err := t.b.buildNodeExec(ctx, logic, childRow)
   204  	if err != nil {
   205  		return nil, err
   206  	}
   207  
   208  	defer func() {
   209  		err := logicIter.Close(t.ctx)
   210  		if returnErr == nil {
   211  			returnErr = err
   212  		}
   213  	}()
   214  
   215  	var logicRow sql.Row
   216  	for {
   217  		row, err := logicIter.Next(ctx)
   218  		if err == io.EOF {
   219  			break
   220  		}
   221  		if err != nil {
   222  			return nil, err
   223  		}
   224  		logicRow = row
   225  	}
   226  
   227  	// For some logic statements, we want to return the result of the logic operation as our row, e.g. a Set that alters
   228  	// the fields of the new row
   229  	if ok, returnRow := shouldUseLogicResult(logic, logicRow); ok {
   230  		return returnRow, nil
   231  	}
   232  
   233  	return childRow, nil
   234  }
   235  
   236  func shouldUseLogicResult(logic sql.Node, row sql.Row) (bool, sql.Row) {
   237  	switch logic := logic.(type) {
   238  	// TODO: are there other statement types that we should use here?
   239  	case *plan.Set:
   240  		hasSetField := false
   241  		for _, expr := range logic.Exprs {
   242  			sql.Inspect(expr.(*expression.SetField).LeftChild, func(e sql.Expression) bool {
   243  				if _, ok := e.(*expression.GetField); ok {
   244  					hasSetField = true
   245  					return false
   246  				}
   247  				return true
   248  			})
   249  		}
   250  		return hasSetField, row[len(row)/2:]
   251  	case *plan.TriggerBeginEndBlock:
   252  		hasSetField := false
   253  		transform.Inspect(logic, func(n sql.Node) bool {
   254  			set, ok := n.(*plan.Set)
   255  			if !ok {
   256  				return true
   257  			}
   258  			for _, expr := range set.Exprs {
   259  				sql.Inspect(expr.(*expression.SetField).LeftChild, func(e sql.Expression) bool {
   260  					if _, ok := e.(*expression.GetField); ok {
   261  						hasSetField = true
   262  						return false
   263  					}
   264  					return true
   265  				})
   266  			}
   267  			return !hasSetField
   268  		})
   269  		return hasSetField, row
   270  	default:
   271  		return false, nil
   272  	}
   273  }
   274  
   275  func (t *triggerIter) Close(ctx *sql.Context) error {
   276  	return t.child.Close(ctx)
   277  }
   278  
   279  type accumulatorRowHandler interface {
   280  	handleRowUpdate(row sql.Row) error
   281  	okResult() types.OkResult
   282  }
   283  
   284  // TODO: Extend this to UPDATE IGNORE JOIN
   285  type updateIgnoreAccumulatorRowHandler interface {
   286  	accumulatorRowHandler
   287  	handleRowUpdateWithIgnore(row sql.Row, ignore bool) error
   288  }
   289  
   290  type insertRowHandler struct {
   291  	rowsAffected              int
   292  	lastInsertId              uint64
   293  	updatedAutoIncrementValue bool
   294  	lastInsertIdGetter        func(row sql.Row) int64
   295  }
   296  
   297  func (i *insertRowHandler) handleRowUpdate(row sql.Row) error {
   298  	if !i.updatedAutoIncrementValue {
   299  		i.updatedAutoIncrementValue = true
   300  		i.lastInsertId = uint64(i.lastInsertIdGetter(row))
   301  	}
   302  	i.rowsAffected++
   303  	return nil
   304  }
   305  
   306  func (i *insertRowHandler) okResult() types.OkResult {
   307  	return types.OkResult{
   308  		RowsAffected: uint64(i.rowsAffected),
   309  		InsertID:     i.lastInsertId,
   310  	}
   311  }
   312  
   313  type replaceRowHandler struct {
   314  	rowsAffected int
   315  }
   316  
   317  func (r *replaceRowHandler) handleRowUpdate(row sql.Row) error {
   318  	r.rowsAffected++
   319  
   320  	// If a row was deleted as well as inserted, increment the counter again. A row was deleted if at least one column in
   321  	// the first half of the row is non-null.
   322  	for i := 0; i < len(row)/2; i++ {
   323  		if row[i] != nil {
   324  			r.rowsAffected++
   325  			break
   326  		}
   327  	}
   328  
   329  	return nil
   330  }
   331  
   332  func (r *replaceRowHandler) okResult() types.OkResult {
   333  	return types.NewOkResult(r.rowsAffected)
   334  }
   335  
   336  type onDuplicateUpdateHandler struct {
   337  	rowsAffected              int
   338  	schema                    sql.Schema
   339  	clientFoundRowsCapability bool
   340  }
   341  
   342  func (o *onDuplicateUpdateHandler) handleRowUpdate(row sql.Row) error {
   343  	// See https://dev.mysql.com/doc/refman/8.0/en/insert-on-duplicate.html for row count semantics
   344  	// If a row was inserted, increment by 1
   345  	if len(row) == len(o.schema) {
   346  		o.rowsAffected++
   347  		return nil
   348  	}
   349  
   350  	// Otherwise (a row was updated), increment by 2 if the row changed, 0 if not
   351  	oldRow := row[:len(row)/2]
   352  	newRow := row[len(row)/2:]
   353  	if equals, err := oldRow.Equals(newRow, o.schema); err == nil {
   354  		if equals {
   355  			// Ig the CLIENT_FOUND_ROWS capabilities flag is set, increment by 1 if a row stays the same.
   356  			if o.clientFoundRowsCapability {
   357  				o.rowsAffected++
   358  			}
   359  		} else {
   360  			o.rowsAffected += 2
   361  		}
   362  	} else {
   363  		o.rowsAffected++
   364  	}
   365  
   366  	return nil
   367  }
   368  
   369  func (o *onDuplicateUpdateHandler) okResult() types.OkResult {
   370  	return types.NewOkResult(o.rowsAffected)
   371  }
   372  
   373  type updateRowHandler struct {
   374  	rowsMatched               int
   375  	rowsAffected              int
   376  	schema                    sql.Schema
   377  	clientFoundRowsCapability bool
   378  }
   379  
   380  func (u *updateRowHandler) handleRowUpdate(row sql.Row) error {
   381  	u.rowsMatched++
   382  	oldRow := row[:len(row)/2]
   383  	newRow := row[len(row)/2:]
   384  	if equals, err := oldRow.Equals(newRow, u.schema); err == nil {
   385  		if !equals {
   386  			u.rowsAffected++
   387  		}
   388  	} else {
   389  		return err
   390  	}
   391  	return nil
   392  }
   393  
   394  func (u *updateRowHandler) handleRowUpdateWithIgnore(row sql.Row, ignore bool) error {
   395  	if !ignore {
   396  		return u.handleRowUpdate(row)
   397  	}
   398  
   399  	u.rowsMatched++
   400  	return nil
   401  }
   402  
   403  func (u *updateRowHandler) okResult() types.OkResult {
   404  	affected := u.rowsAffected
   405  	if u.clientFoundRowsCapability {
   406  		affected = u.rowsMatched
   407  	}
   408  	return types.OkResult{
   409  		RowsAffected: uint64(affected),
   410  		Info: plan.UpdateInfo{
   411  			Matched:  u.rowsMatched,
   412  			Updated:  u.rowsAffected,
   413  			Warnings: 0,
   414  		},
   415  	}
   416  }
   417  
   418  func (u *updateRowHandler) RowsMatched() int64 {
   419  	return int64(u.rowsMatched)
   420  }
   421  
   422  // updateJoinRowHandler handles row update count for all UPDATEs that use a JOIN.
   423  type updateJoinRowHandler struct {
   424  	rowsMatched  int
   425  	rowsAffected int
   426  	joinSchema   sql.Schema
   427  	tableMap     map[string]sql.Schema // Needs to only be the tables that can be updated.
   428  	updaterMap   map[string]sql.RowUpdater
   429  }
   430  
   431  func (u *updateJoinRowHandler) handleRowUpdate(row sql.Row) error {
   432  	oldJoinRow := row[:len(row)/2]
   433  	newJoinRow := row[len(row)/2:]
   434  
   435  	tableToOldRow := plan.SplitRowIntoTableRowMap(oldJoinRow, u.joinSchema)
   436  	tableToNewRow := plan.SplitRowIntoTableRowMap(newJoinRow, u.joinSchema)
   437  
   438  	for tableName, _ := range u.updaterMap {
   439  		u.rowsMatched++ // TODO: This currently returns the incorrect answer
   440  		tableOldRow := tableToOldRow[tableName]
   441  		tableNewRow := tableToNewRow[tableName]
   442  		if equals, err := tableOldRow.Equals(tableNewRow, u.tableMap[tableName]); err == nil {
   443  			if !equals {
   444  				u.rowsAffected++
   445  			}
   446  		} else {
   447  			return err
   448  		}
   449  	}
   450  	return nil
   451  }
   452  
   453  func (u *updateJoinRowHandler) okResult() types.OkResult {
   454  	return types.OkResult{
   455  		RowsAffected: uint64(u.rowsAffected),
   456  		Info: plan.UpdateInfo{
   457  			Matched:  u.rowsMatched,
   458  			Updated:  u.rowsAffected,
   459  			Warnings: 0,
   460  		},
   461  	}
   462  }
   463  
   464  func (u *updateJoinRowHandler) RowsMatched() int64 {
   465  	return int64(u.rowsMatched)
   466  }
   467  
   468  type deleteRowHandler struct {
   469  	rowsAffected int
   470  }
   471  
   472  func (u *deleteRowHandler) handleRowUpdate(row sql.Row) error {
   473  	u.rowsAffected++
   474  	return nil
   475  }
   476  
   477  func (u *deleteRowHandler) okResult() types.OkResult {
   478  	return types.NewOkResult(u.rowsAffected)
   479  }
   480  
   481  type accumulatorIter struct {
   482  	iter             sql.RowIter
   483  	once             sync.Once
   484  	updateRowHandler accumulatorRowHandler
   485  }
   486  
   487  func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) {
   488  	run := false
   489  	a.once.Do(func() {
   490  		run = true
   491  	})
   492  
   493  	if !run {
   494  		return nil, io.EOF
   495  	}
   496  
   497  	oldLastInsertId := ctx.Session.GetLastQueryInfo(sql.LastInsertId)
   498  	if oldLastInsertId != 0 {
   499  		ctx.Session.SetLastQueryInfo(sql.LastInsertId, -1)
   500  	}
   501  
   502  	// We close our child iterator before returning any results. In
   503  	// particular, the LOAD DATA source iterator needs to be closed before
   504  	// results are returned.
   505  	defer func() {
   506  		cerr := a.iter.Close(ctx)
   507  		if err == nil {
   508  			err = cerr
   509  		}
   510  	}()
   511  
   512  	for {
   513  		row, err := a.iter.Next(ctx)
   514  		igErr, isIg := err.(sql.IgnorableError)
   515  		select {
   516  		case <-ctx.Done():
   517  			return nil, ctx.Err()
   518  		default:
   519  		}
   520  		if err == io.EOF {
   521  			// TODO: The information flow here is pretty gnarly. We
   522  			// set some session variables based on the result, and
   523  			// we actually use a session variable to set
   524  			// InsertID. This should be improved.
   525  
   526  			// UPDATE statements also set FoundRows to the number of rows that
   527  			// matched the WHERE clause, same as a SELECT.
   528  			if ma, ok := a.updateRowHandler.(matchingAccumulator); ok {
   529  				ctx.SetLastQueryInfo(sql.FoundRows, ma.RowsMatched())
   530  			}
   531  
   532  			newLastInsertId := ctx.Session.GetLastQueryInfo(sql.LastInsertId)
   533  			if newLastInsertId == -1 {
   534  				ctx.Session.SetLastQueryInfo(sql.LastInsertId, oldLastInsertId)
   535  			}
   536  
   537  			res := a.updateRowHandler.okResult() // TODO: Should add warnings here
   538  
   539  			// For some update accumulators, we don't accurately track the last insert ID in the handler and need to set
   540  			// it manually in the result by getting it from the session. This doesn't work correctly in all cases and needs
   541  			// to be fixed. See comment in buildRowUpdateAccumulator in rowexec/dml.go
   542  			switch a.updateRowHandler.(type) {
   543  			case *onDuplicateUpdateHandler, *replaceRowHandler:
   544  				res.InsertID = uint64(newLastInsertId)
   545  			}
   546  
   547  			// By definition, ROW_COUNT() is equal to RowsAffected.
   548  			ctx.SetLastQueryInfo(sql.RowCount, int64(res.RowsAffected))
   549  
   550  			return sql.NewRow(res), nil
   551  		} else if isIg {
   552  			if ui, ok := a.updateRowHandler.(updateIgnoreAccumulatorRowHandler); ok {
   553  				err = ui.handleRowUpdateWithIgnore(igErr.OffendingRow, true)
   554  				if err != nil {
   555  					return nil, err
   556  				}
   557  			}
   558  		} else if err != nil {
   559  			return nil, err
   560  		} else {
   561  			err = a.updateRowHandler.handleRowUpdate(row)
   562  			if err != nil {
   563  				return nil, err
   564  			}
   565  		}
   566  	}
   567  }
   568  
   569  func (a *accumulatorIter) Close(ctx *sql.Context) error {
   570  	return nil
   571  }
   572  
   573  type matchingAccumulator interface {
   574  	RowsMatched() int64
   575  }
   576  
   577  type updateSourceIter struct {
   578  	childIter   sql.RowIter
   579  	updateExprs []sql.Expression
   580  	tableSchema sql.Schema
   581  	ignore      bool
   582  }
   583  
   584  func (u *updateSourceIter) Next(ctx *sql.Context) (sql.Row, error) {
   585  	oldRow, err := u.childIter.Next(ctx)
   586  	if err != nil {
   587  		return nil, err
   588  	}
   589  
   590  	newRow, err := applyUpdateExpressionsWithIgnore(ctx, u.updateExprs, u.tableSchema, oldRow, u.ignore)
   591  	if err != nil {
   592  		return nil, err
   593  	}
   594  
   595  	// Reduce the row to the length of the schema. The length can differ when some update values come from an outer
   596  	// scope, which will be the first N values in the row.
   597  	// TODO: handle this in the analyzer instead?
   598  	expectedSchemaLen := len(u.tableSchema)
   599  	if expectedSchemaLen < len(oldRow) {
   600  		oldRow = oldRow[len(oldRow)-expectedSchemaLen:]
   601  		newRow = newRow[len(newRow)-expectedSchemaLen:]
   602  	}
   603  
   604  	return oldRow.Append(newRow), nil
   605  }
   606  
   607  func (u *updateSourceIter) Close(ctx *sql.Context) error {
   608  	return u.childIter.Close(ctx)
   609  }