github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/update.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  	"github.com/dolthub/go-mysql-server/sql/plan"
    23  )
    24  
// updateIter performs the row-by-row work of an UPDATE. Each row read from childIter is the old row followed by the
// new row; when the two differ, the new row is validated (enforced CHECK constraints, NOT NULL columns) and the pair
// is handed to updater.
type updateIter struct {
	// childIter yields rows formed by concatenating the old row and the new row.
	childIter sql.RowIter
	// schema describes one half (old or new) of each row produced by childIter.
	schema    sql.Schema
	// updater receives each changed (oldRow, newRow) pair and is closed by Close.
	updater   sql.RowUpdater
	// checks are evaluated against every changed new row; constraints that are not Enforced are skipped.
	checks    sql.CheckConstraints
	// closed makes Close idempotent.
	closed    bool
	// ignore enables IGNORE behaviour: errors are routed through warnOnIgnorableError instead of being returned as-is.
	ignore    bool
}
    33  
    34  func (u *updateIter) Next(ctx *sql.Context) (sql.Row, error) {
    35  	oldAndNewRow, err := u.childIter.Next(ctx)
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  
    40  	oldRow, newRow := oldAndNewRow[:len(oldAndNewRow)/2], oldAndNewRow[len(oldAndNewRow)/2:]
    41  	if equals, err := oldRow.Equals(newRow, u.schema); err == nil {
    42  		if !equals {
    43  			// apply check constraints
    44  			for _, check := range u.checks {
    45  				if !check.Enforced {
    46  					continue
    47  				}
    48  
    49  				res, err := sql.EvaluateCondition(ctx, check.Expr, newRow)
    50  				if err != nil {
    51  					return nil, err
    52  				}
    53  
    54  				if sql.IsFalse(res) {
    55  					return nil, u.ignoreOrError(ctx, newRow, sql.ErrCheckConstraintViolated.New(check.Name))
    56  				}
    57  			}
    58  
    59  			err := u.validateNullability(ctx, newRow, u.schema)
    60  			if err != nil {
    61  				return nil, u.ignoreOrError(ctx, newRow, err)
    62  			}
    63  
    64  			err = u.updater.Update(ctx, oldRow, newRow)
    65  			if err != nil {
    66  				return nil, u.ignoreOrError(ctx, newRow, err)
    67  			}
    68  		}
    69  	} else {
    70  		return nil, err
    71  	}
    72  
    73  	return oldAndNewRow, nil
    74  }
    75  
    76  // Applies the update expressions given to the row given, returning the new resultant row. In the case that ignore is
    77  // provided and there is a type conversion error, this function sets the value to the zero value as per the MySQL standard.
    78  func applyUpdateExpressionsWithIgnore(ctx *sql.Context, updateExprs []sql.Expression, tableSchema sql.Schema, row sql.Row, ignore bool) (sql.Row, error) {
    79  	var secondPass []int
    80  
    81  	for i, updateExpr := range updateExprs {
    82  		defaultVal, isDefaultVal := defaultValFromSetExpression(updateExpr)
    83  		// Any generated columns must be projected into place so that the caller gets their newest values as well. We
    84  		// do this in a second pass as necessary.
    85  		if isDefaultVal && !defaultVal.IsLiteral() {
    86  			secondPass = append(secondPass, i)
    87  			continue
    88  		}
    89  
    90  		val, err := updateExpr.Eval(ctx, row)
    91  		if err != nil {
    92  			var wtce sql.WrappedTypeConversionError
    93  			isTypeConversionError := errors.As(err, &wtce)
    94  			if !isTypeConversionError || !ignore {
    95  				return nil, err
    96  			}
    97  
    98  			cpy := row.Copy()
    99  			cpy[wtce.OffendingIdx] = wtce.OffendingVal // Needed for strings
   100  			val = convertDataAndWarn(ctx, tableSchema, cpy, wtce.OffendingIdx, wtce.Err)
   101  		}
   102  		var ok bool
   103  		row, ok = val.(sql.Row)
   104  		if !ok {
   105  			return nil, plan.ErrUpdateUnexpectedSetResult.New(val)
   106  		}
   107  	}
   108  
   109  	for _, index := range secondPass {
   110  		val, err := updateExprs[index].Eval(ctx, row)
   111  		if err != nil {
   112  			return nil, err
   113  		}
   114  
   115  		var ok bool
   116  		row, ok = val.(sql.Row)
   117  		if !ok {
   118  			return nil, plan.ErrUpdateUnexpectedSetResult.New(val)
   119  		}
   120  	}
   121  
   122  	return row, nil
   123  }
   124  
   125  func (u *updateIter) validateNullability(ctx *sql.Context, row sql.Row, schema sql.Schema) error {
   126  	for idx := 0; idx < len(row); idx++ {
   127  		col := schema[idx]
   128  		if !col.Nullable && row[idx] == nil {
   129  			// In the case of an IGNORE we set the nil value to a default and add a warning
   130  			if u.ignore {
   131  				row[idx] = col.Type.Zero()
   132  				_ = warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil
   133  			} else {
   134  				return sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)
   135  			}
   136  
   137  		}
   138  	}
   139  	return nil
   140  }
   141  
   142  func (u *updateIter) Close(ctx *sql.Context) error {
   143  	if !u.closed {
   144  		u.closed = true
   145  		if err := u.updater.Close(ctx); err != nil {
   146  			return err
   147  		}
   148  		return u.childIter.Close(ctx)
   149  	}
   150  	return nil
   151  }
   152  
   153  func (u *updateIter) ignoreOrError(ctx *sql.Context, row sql.Row, err error) error {
   154  	if !u.ignore {
   155  		return err
   156  	}
   157  
   158  	return warnOnIgnorableError(ctx, row, err)
   159  }
   160  
   161  func newUpdateIter(
   162  	childIter sql.RowIter,
   163  	schema sql.Schema,
   164  	updater sql.RowUpdater,
   165  	checks sql.CheckConstraints,
   166  	ignore bool,
   167  ) sql.RowIter {
   168  	if ignore {
   169  		return plan.NewCheckpointingTableEditorIter(&updateIter{
   170  			childIter: childIter,
   171  			updater:   updater,
   172  			schema:    schema,
   173  			checks:    checks,
   174  			ignore:    true,
   175  		}, updater)
   176  	} else {
   177  		return plan.NewTableEditorIter(&updateIter{
   178  			childIter: childIter,
   179  			updater:   updater,
   180  			schema:    schema,
   181  			checks:    checks,
   182  		}, updater)
   183  	}
   184  }
   185  
// updateJoinIter wraps the child UpdateSource projectIter and returns join rows in such a way that each table row is
// updated at most once, even when it appears in several join rows.
type updateJoinIter struct {
	// updateSourceIter yields rows formed by concatenating the old join row and the new join row.
	updateSourceIter sql.RowIter
	// joinSchema describes one half (old or new) of each source row. Its column Sources are used to split a join
	// row into per-table rows and to reassemble it.
	joinSchema       sql.Schema
	// updaters is keyed by the name of each table being updated. Next iterates only over its keys.
	updaters         map[string]sql.RowUpdater
	// caches holds, per table name, the hashes of old table rows already seen. Entries are created lazily by
	// getOrCreateCache.
	caches           map[string]sql.KeyValueCache
	// disposals holds the dispose function of each entry in caches. All of them are invoked by Close.
	disposals        map[string]sql.DisposeFunc
	// joinNode is the join plan node, possibly under wrapper nodes that toJoinNode unwraps. It is used to detect
	// left outer joins.
	joinNode         sql.Node
}

var _ sql.RowIter = (*updateJoinIter)(nil)
   198  
   199  func (u *updateJoinIter) Next(ctx *sql.Context) (sql.Row, error) {
   200  	for {
   201  		oldAndNewRow, err := u.updateSourceIter.Next(ctx)
   202  		if err != nil {
   203  			return nil, err
   204  		}
   205  
   206  		oldJoinRow, newJoinRow := oldAndNewRow[:len(oldAndNewRow)/2], oldAndNewRow[len(oldAndNewRow)/2:]
   207  
   208  		tableToOldRowMap := plan.SplitRowIntoTableRowMap(oldJoinRow, u.joinSchema)
   209  		tableToNewRowMap := plan.SplitRowIntoTableRowMap(newJoinRow, u.joinSchema)
   210  
   211  		for tableName, _ := range u.updaters {
   212  			oldTableRow := tableToOldRowMap[tableName]
   213  
   214  			// Handle the case of row being ignored due to it not being valid in the join row.
   215  			if isRightOrLeftJoin(u.joinNode) {
   216  				works, err := u.shouldUpdateDirectionalJoin(ctx, oldJoinRow, oldTableRow)
   217  				if err != nil {
   218  					return nil, err
   219  				}
   220  
   221  				if !works {
   222  					// rewrite the newJoinRow to ensure an update does not happen
   223  					tableToNewRowMap[tableName] = oldTableRow
   224  					continue
   225  				}
   226  			}
   227  
   228  			// Determine whether this row in the table has already been updated
   229  			cache := u.getOrCreateCache(ctx, tableName)
   230  			hash, err := sql.HashOf(oldTableRow)
   231  			if err != nil {
   232  				return nil, err
   233  			}
   234  
   235  			_, err = cache.Get(hash)
   236  			if errors.Is(err, sql.ErrKeyNotFound) {
   237  				cache.Put(hash, struct{}{})
   238  				continue
   239  			} else if err != nil {
   240  				return nil, err
   241  			}
   242  
   243  			// If this row for the table has already been updated we rewrite the newJoinRow counterpart to ensure that this
   244  			// returned row is not incorrectly counted by the update accumulator.
   245  			tableToNewRowMap[tableName] = oldTableRow
   246  		}
   247  
   248  		newJoinRow = recreateRowFromMap(tableToNewRowMap, u.joinSchema)
   249  		equals, err := oldJoinRow.Equals(newJoinRow, u.joinSchema)
   250  		if err != nil {
   251  			return nil, err
   252  		}
   253  		if !equals {
   254  			return append(oldJoinRow, newJoinRow...), nil
   255  		}
   256  	}
   257  }
   258  
   259  func toJoinNode(node sql.Node) *plan.JoinNode {
   260  	switch n := node.(type) {
   261  	case *plan.JoinNode:
   262  		return n
   263  	case *plan.TopN:
   264  		return toJoinNode(n.Child)
   265  	case *plan.Filter:
   266  		return toJoinNode(n.Child)
   267  	case *plan.Project:
   268  		return toJoinNode(n.Child)
   269  	case *plan.Limit:
   270  		return toJoinNode(n.Child)
   271  	case *plan.Offset:
   272  		return toJoinNode(n.Child)
   273  	case *plan.Sort:
   274  		return toJoinNode(n.Child)
   275  	case *plan.Distinct:
   276  		return toJoinNode(n.Child)
   277  	case *plan.Having:
   278  		return toJoinNode(n.Child)
   279  	case *plan.Window:
   280  		return toJoinNode(n.Child)
   281  	default:
   282  		return nil
   283  	}
   284  }
   285  
   286  func isIndexedAccess(node sql.Node) bool {
   287  	switch n := node.(type) {
   288  	case *plan.Filter:
   289  		return isIndexedAccess(n.Child)
   290  	case *plan.TableAlias:
   291  		return isIndexedAccess(n.Child)
   292  	case *plan.JoinNode:
   293  		return isIndexedAccess(n.Left())
   294  	case *plan.IndexedTableAccess:
   295  		return true
   296  	}
   297  	return false
   298  }
   299  
   300  func isRightOrLeftJoin(node sql.Node) bool {
   301  	jn := toJoinNode(node)
   302  	if jn == nil {
   303  		return false
   304  	}
   305  	return jn.JoinType().IsLeftOuter()
   306  }
   307  
   308  // shouldUpdateDirectionalJoin determines whether a table row should be updated in the context of a large right/left join row.
   309  // A table row should only be updated if 1) It fits the join conditions (the intersection of the join) 2) It fits only
   310  // the left or right side of the join (given the direction). A row of all nils that does not pass condition 1 must not
   311  // be part of the update operation. This is follows the logic as established in the joinIter.
   312  func (u *updateJoinIter) shouldUpdateDirectionalJoin(ctx *sql.Context, joinRow, tableRow sql.Row) (bool, error) {
   313  	jn := toJoinNode(u.joinNode)
   314  	if jn == nil || !jn.JoinType().IsLeftOuter() {
   315  		return true, fmt.Errorf("expected left join")
   316  	}
   317  
   318  	// If the overall row fits the join condition it is fine (i.e. middle of the venn diagram).
   319  	val, err := jn.JoinCond().Eval(ctx, joinRow)
   320  	if err != nil {
   321  		return true, err
   322  	}
   323  	if v, ok := val.(bool); ok && v && !isIndexedAccess(jn) {
   324  		return true, nil
   325  	}
   326  
   327  	for _, v := range tableRow {
   328  		if v != nil {
   329  			return true, nil
   330  		}
   331  	}
   332  
   333  	// If the row is all nils we know it should not be updated as per the function description.
   334  	return false, nil
   335  }
   336  
   337  func (u *updateJoinIter) Close(context *sql.Context) error {
   338  	for _, disposeF := range u.disposals {
   339  		disposeF()
   340  	}
   341  
   342  	return u.updateSourceIter.Close(context)
   343  }
   344  
   345  func (u *updateJoinIter) getOrCreateCache(ctx *sql.Context, tableName string) sql.KeyValueCache {
   346  	potential, exists := u.caches[tableName]
   347  	if exists {
   348  		return potential
   349  	}
   350  
   351  	cache, disposal := ctx.Memory.NewHistoryCache()
   352  	u.caches[tableName] = cache
   353  	u.disposals[tableName] = disposal
   354  
   355  	return cache
   356  }
   357  
   358  // recreateRowFromMap takes a join schema and row map and recreates the original join row.
   359  func recreateRowFromMap(rowMap map[string]sql.Row, joinSchema sql.Schema) sql.Row {
   360  	var ret sql.Row
   361  
   362  	if len(joinSchema) == 0 {
   363  		return ret
   364  	}
   365  
   366  	currentTable := joinSchema[0].Source
   367  	ret = append(ret, rowMap[currentTable]...)
   368  
   369  	for i := 1; i < len(joinSchema); i++ {
   370  		c := joinSchema[i]
   371  
   372  		if c.Source != currentTable {
   373  			ret = append(ret, rowMap[c.Source]...)
   374  			currentTable = c.Source
   375  		}
   376  	}
   377  
   378  	return ret
   379  }