github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/insert.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  
    21  	"github.com/dolthub/vitess/go/vt/proto/query"
    22  	"gopkg.in/src-d/go-errors.v1"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/expression"
    26  	"github.com/dolthub/go-mysql-server/sql/expression/function"
    27  	"github.com/dolthub/go-mysql-server/sql/plan"
    28  	"github.com/dolthub/go-mysql-server/sql/transform"
    29  	"github.com/dolthub/go-mysql-server/sql/types"
    30  )
    31  
// insertIter executes INSERT, REPLACE, and INSERT ... ON DUPLICATE KEY UPDATE row by row: each call to Next pulls one
// row from rowSource and writes it through inserter, replacer, or updater.
type insertIter struct {
	// schema is the destination schema; source rows are trimmed to its width and converted to its column types.
	schema              sql.Schema
	// inserter writes rows for plain INSERT and for INSERT ... ON DUPLICATE KEY UPDATE.
	inserter            sql.RowInserter
	// replacer, when non-nil, selects REPLACE semantics: conflicting rows are deleted before inserting.
	replacer            sql.RowReplacer
	// updater applies ON DUPLICATE KEY UPDATE changes to the conflicting existing row.
	updater             sql.RowUpdater
	// rowSource yields the rows to insert.
	rowSource           sql.RowIter
	// lastInsertIdUpdated is set once LAST_INSERT_ID has been recorded, so it is only recorded for the first row.
	lastInsertIdUpdated bool
	// hasAutoAutoIncValue gates LAST_INSERT_ID bookkeeping in updateLastInsertId. Presumably true when the statement
	// generates an auto increment value rather than receiving an explicit one — TODO confirm at the construction site.
	hasAutoAutoIncValue bool
	// ctx is the context used to evaluate ON DUPLICATE KEY UPDATE expressions.
	ctx                 *sql.Context
	// insertExprs are scanned by getAutoIncVal for an *expression.AutoIncrement. Presumably populated from
	// getInsertExpressions — verify against the caller.
	insertExprs         []sql.Expression
	// updateExprs are the ON DUPLICATE KEY UPDATE assignments; when empty, key conflicts are reported as errors.
	updateExprs         []sql.Expression
	// checks are the CHECK constraints evaluated against every row (and again after an ON DUPLICATE KEY UPDATE).
	checks              sql.CheckConstraints
	// tableNode is the destination table node (not referenced in the code visible here).
	tableNode           sql.Node
	// closed makes Close idempotent.
	closed              bool
	// ignore is set for INSERT IGNORE; ignorable errors become warnings instead of failing the statement.
	ignore              bool
}
    48  
    49  func getInsertExpressions(values sql.Node) []sql.Expression {
    50  	var exprs []sql.Expression
    51  	transform.Inspect(values, func(node sql.Node) bool {
    52  		switch node := node.(type) {
    53  		case *plan.Project:
    54  			exprs = node.Projections
    55  			return false
    56  		}
    57  		return true
    58  	})
    59  	return exprs
    60  }
    61  
    62  func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) {
    63  	row, err := i.rowSource.Next(ctx)
    64  	if err == io.EOF {
    65  		return nil, err
    66  	}
    67  
    68  	if err != nil {
    69  		return nil, i.ignoreOrClose(ctx, row, err)
    70  	}
    71  
    72  	// Prune the row down to the size of the schema. It can be larger in the case of running with an outer scope, in which
    73  	// case the additional scope variables are prepended to the row.
    74  	if len(row) > len(i.schema) {
    75  		row = row[len(row)-len(i.schema):]
    76  	}
    77  
    78  	err = i.validateNullability(ctx, i.schema, row)
    79  	if err != nil {
    80  		return nil, i.ignoreOrClose(ctx, row, err)
    81  	}
    82  
    83  	err = i.evaluateChecks(ctx, row)
    84  	if err != nil {
    85  		return nil, i.ignoreOrClose(ctx, row, err)
    86  	}
    87  
    88  	origRow := make(sql.Row, len(row))
    89  	copy(origRow, row)
    90  
    91  	// Do any necessary type conversions to the target schema
    92  	for idx, col := range i.schema {
    93  		if row[idx] != nil {
    94  			converted, inRange, cErr := col.Type.Convert(row[idx])
    95  			if cErr == nil && !inRange {
    96  				cErr = sql.ErrValueOutOfRange.New(row[idx], col.Type)
    97  			}
    98  			if cErr != nil {
    99  				// Ignore individual column errors when INSERT IGNORE, UPDATE IGNORE, etc. is specified.
   100  				// For JSON column types, always throw an error. MySQL throws the following error even when
   101  				// IGNORE is specified:
   102  				// ERROR 3140 (22032): Invalid JSON text: "Invalid value." at position 0 in value for column
   103  				// 'table.column'.
   104  				if i.ignore && col.Type.Type() != query.Type_JSON {
   105  					if _, ok := col.Type.(sql.NumberType); ok {
   106  						if converted == nil {
   107  							converted = i.schema[idx].Type.Zero()
   108  						}
   109  						row[idx] = converted
   110  						// Add a warning instead
   111  						ctx.Session.Warn(&sql.Warning{
   112  							Level:   "Note",
   113  							Code:    sql.CastSQLError(cErr).Num,
   114  							Message: cErr.Error(),
   115  						})
   116  					} else {
   117  						row = convertDataAndWarn(ctx, i.schema, row, idx, cErr)
   118  					}
   119  					continue
   120  				} else {
   121  					// Fill in error with information
   122  					if types.ErrLengthBeyondLimit.Is(cErr) {
   123  						cErr = types.ErrLengthBeyondLimit.New(row[idx], col.Name)
   124  					} else if sql.ErrNotMatchingSRID.Is(cErr) {
   125  						cErr = sql.ErrNotMatchingSRIDWithColName.New(col.Name, cErr)
   126  					}
   127  					return nil, sql.NewWrappedInsertError(origRow, cErr)
   128  				}
   129  			}
   130  			row[idx] = converted
   131  		}
   132  	}
   133  
   134  	if i.replacer != nil {
   135  		toReturn := make(sql.Row, len(row)*2)
   136  		for i := 0; i < len(row); i++ {
   137  			toReturn[i+len(row)] = row[i]
   138  		}
   139  		// May have multiple duplicate pk & unique errors due to multiple indexes
   140  		//TODO: how does this interact with triggers?
   141  		for {
   142  			if err := i.replacer.Insert(ctx, row); err != nil {
   143  				if !sql.ErrPrimaryKeyViolation.Is(err) && !sql.ErrUniqueKeyViolation.Is(err) {
   144  					i.rowSource.Close(ctx)
   145  					i.rowSource = nil
   146  					return nil, sql.NewWrappedInsertError(row, err)
   147  				}
   148  
   149  				ue := err.(*errors.Error).Cause().(sql.UniqueKeyError)
   150  				if err = i.replacer.Delete(ctx, ue.Existing); err != nil {
   151  					i.rowSource.Close(ctx)
   152  					i.rowSource = nil
   153  					return nil, sql.NewWrappedInsertError(row, err)
   154  				}
   155  				// the row had to be deleted, write the values into the toReturn row
   156  				copy(toReturn, ue.Existing)
   157  			} else {
   158  				break
   159  			}
   160  		}
   161  		return toReturn, nil
   162  	} else {
   163  		if err := i.inserter.Insert(ctx, row); err != nil {
   164  			if (!sql.ErrPrimaryKeyViolation.Is(err) && !sql.ErrUniqueKeyViolation.Is(err) && !sql.ErrDuplicateEntry.Is(err)) || len(i.updateExprs) == 0 {
   165  				return nil, i.ignoreOrClose(ctx, row, err)
   166  			}
   167  
   168  			ue := err.(*errors.Error).Cause().(sql.UniqueKeyError)
   169  			return i.handleOnDuplicateKeyUpdate(ctx, ue.Existing, row)
   170  		}
   171  	}
   172  
   173  	i.updateLastInsertId(ctx, row)
   174  
   175  	return row, nil
   176  }
   177  
   178  func (i *insertIter) handleOnDuplicateKeyUpdate(ctx *sql.Context, oldRow, newRow sql.Row) (returnRow sql.Row, returnErr error) {
   179  	var err error
   180  	updateAcc := append(oldRow, newRow...)
   181  	var evalRow sql.Row
   182  	for _, updateExpr := range i.updateExprs {
   183  		// this SET <val> indexes into LHS, but the <expr> can
   184  		// reference the new row on RHS
   185  		val, err := updateExpr.Eval(i.ctx, updateAcc)
   186  		if err != nil {
   187  			if i.ignore {
   188  				idx, ok := getFieldIndexFromUpdateExpr(updateExpr)
   189  				if !ok {
   190  					return nil, err
   191  				}
   192  
   193  				val = convertDataAndWarn(ctx, i.schema, newRow, idx, err)
   194  			} else {
   195  				return nil, err
   196  			}
   197  		}
   198  
   199  		updateAcc = val.(sql.Row)
   200  	}
   201  	// project LHS only
   202  	evalRow = updateAcc[:len(oldRow)]
   203  
   204  	// Should revaluate the check conditions.
   205  	err = i.evaluateChecks(ctx, evalRow)
   206  	if err != nil {
   207  		return nil, i.ignoreOrClose(ctx, newRow, err)
   208  	}
   209  
   210  	err = i.updater.Update(ctx, oldRow, evalRow)
   211  	if err != nil {
   212  		return nil, i.ignoreOrClose(ctx, newRow, err)
   213  	}
   214  
   215  	// In the case that we attempted an update, return a concatenated [old,new] row just like update.
   216  	return oldRow.Append(evalRow), nil
   217  }
   218  
   219  func getFieldIndexFromUpdateExpr(updateExpr sql.Expression) (int, bool) {
   220  	setField, ok := updateExpr.(*expression.SetField)
   221  	if !ok {
   222  		return 0, false
   223  	}
   224  
   225  	getField, ok := setField.LeftChild.(*expression.GetField)
   226  	if !ok {
   227  		return 0, false
   228  	}
   229  
   230  	return getField.Index(), true
   231  }
   232  
   233  // resolveValues resolves all VALUES functions.
   234  func (i *insertIter) resolveValues(ctx *sql.Context, insertRow sql.Row) error {
   235  	for _, updateExpr := range i.updateExprs {
   236  		var err error
   237  		sql.Inspect(updateExpr, func(expr sql.Expression) bool {
   238  			valuesExpr, ok := expr.(*function.Values)
   239  			if !ok {
   240  				return true
   241  			}
   242  			getField, ok := valuesExpr.Child.(*expression.GetField)
   243  			if !ok {
   244  				err = fmt.Errorf("VALUES functions may only contain column names")
   245  				return false
   246  			}
   247  			valuesExpr.Value = insertRow[getField.Index()]
   248  			return false
   249  		})
   250  		if err != nil {
   251  			return err
   252  		}
   253  	}
   254  	return nil
   255  }
   256  
   257  func (i *insertIter) Close(ctx *sql.Context) error {
   258  	if !i.closed {
   259  		i.closed = true
   260  		var rsErr, iErr, rErr, uErr error
   261  		if i.rowSource != nil {
   262  			rsErr = i.rowSource.Close(ctx)
   263  		}
   264  		if i.inserter != nil {
   265  			iErr = i.inserter.Close(ctx)
   266  		}
   267  		if i.replacer != nil {
   268  			rErr = i.replacer.Close(ctx)
   269  		}
   270  		if i.updater != nil {
   271  			uErr = i.updater.Close(ctx)
   272  		}
   273  		if rsErr != nil {
   274  			return rsErr
   275  		}
   276  		if iErr != nil {
   277  			return iErr
   278  		}
   279  		if rErr != nil {
   280  			return rErr
   281  		}
   282  		if uErr != nil {
   283  			return uErr
   284  		}
   285  	}
   286  	return nil
   287  }
   288  
   289  func (i *insertIter) updateLastInsertId(ctx *sql.Context, row sql.Row) {
   290  	if i.lastInsertIdUpdated {
   291  		return
   292  	}
   293  
   294  	autoIncVal := i.getAutoIncVal(row)
   295  
   296  	if i.hasAutoAutoIncValue {
   297  		ctx.SetLastQueryInfo(sql.LastInsertId, autoIncVal)
   298  		i.lastInsertIdUpdated = true
   299  	}
   300  }
   301  
   302  func (i *insertIter) getAutoIncVal(row sql.Row) int64 {
   303  	var autoIncVal int64
   304  	for i, expr := range i.insertExprs {
   305  		if _, ok := expr.(*expression.AutoIncrement); ok {
   306  			autoIncVal = toInt64(row[i])
   307  			break
   308  		}
   309  	}
   310  	return autoIncVal
   311  }
   312  
   313  func (i *insertIter) ignoreOrClose(ctx *sql.Context, row sql.Row, err error) error {
   314  	if !i.ignore {
   315  		return sql.NewWrappedInsertError(row, err)
   316  	}
   317  
   318  	return warnOnIgnorableError(ctx, row, err)
   319  }
   320  
   321  // convertDataAndWarn modifies a row with data conversion issues in INSERT/UPDATE IGNORE calls
   322  // Per MySQL docs "Rows set to values that would cause data conversion errors are set to the closest valid values instead"
   323  // cc. https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sql-mode-strict
   324  func convertDataAndWarn(ctx *sql.Context, tableSchema sql.Schema, row sql.Row, columnIdx int, err error) sql.Row {
   325  	if types.ErrLengthBeyondLimit.Is(err) {
   326  		maxLength := tableSchema[columnIdx].Type.(sql.StringType).MaxCharacterLength()
   327  		row[columnIdx] = row[columnIdx].(string)[:maxLength] // truncate string
   328  	} else {
   329  		row[columnIdx] = tableSchema[columnIdx].Type.Zero()
   330  	}
   331  
   332  	sqlerr := sql.CastSQLError(err)
   333  
   334  	// Add a warning instead
   335  	if ctx != nil && ctx.Session != nil {
   336  		ctx.Session.Warn(&sql.Warning{
   337  			Level:   "Note",
   338  			Code:    sqlerr.Num,
   339  			Message: err.Error(),
   340  		})
   341  	}
   342  
   343  	return row
   344  }
   345  
   346  func warnOnIgnorableError(ctx *sql.Context, row sql.Row, err error) error {
   347  	// Check that this error is a part of the list of Ignorable Errors and create the relevant warning
   348  	for _, ie := range plan.IgnorableErrors {
   349  		if ie.Is(err) {
   350  			sqlerr := sql.CastSQLError(err)
   351  
   352  			// Add a warning instead
   353  			if ctx != nil && ctx.Session != nil {
   354  				ctx.Session.Warn(&sql.Warning{
   355  					Level:   "Note",
   356  					Code:    sqlerr.Num,
   357  					Message: err.Error(),
   358  				})
   359  			}
   360  
   361  			// In this case the default value gets updated so return nil
   362  			if sql.ErrInsertIntoNonNullableDefaultNullColumn.Is(err) {
   363  				return nil
   364  			}
   365  
   366  			// Return the InsertIgnore err to ensure our accumulator doesn't count this row.
   367  			return sql.NewIgnorableError(row)
   368  		}
   369  	}
   370  
   371  	return err
   372  }
   373  
   374  func (i *insertIter) evaluateChecks(ctx *sql.Context, row sql.Row) error {
   375  	for _, check := range i.checks {
   376  		if !check.Enforced {
   377  			continue
   378  		}
   379  
   380  		res, err := sql.EvaluateCondition(ctx, check.Expr, row)
   381  
   382  		if err != nil {
   383  			return err
   384  		}
   385  
   386  		if sql.IsFalse(res) {
   387  			return sql.ErrCheckConstraintViolated.New(check.Name)
   388  		}
   389  	}
   390  
   391  	return nil
   392  }
   393  
   394  func (i *insertIter) validateNullability(ctx *sql.Context, dstSchema sql.Schema, row sql.Row) error {
   395  	for count, col := range dstSchema {
   396  		if !col.Nullable && row[count] == nil {
   397  			// In the case of an IGNORE we set the nil value to a default and add a warning
   398  			if i.ignore {
   399  				row[count] = col.Type.Zero()
   400  				_ = warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil
   401  			} else {
   402  				return sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)
   403  			}
   404  		}
   405  	}
   406  	return nil
   407  }
   408  
   409  func toInt64(x interface{}) int64 {
   410  	switch x := x.(type) {
   411  	case int:
   412  		return int64(x)
   413  	case uint:
   414  		return int64(x)
   415  	case int8:
   416  		return int64(x)
   417  	case uint8:
   418  		return int64(x)
   419  	case int16:
   420  		return int64(x)
   421  	case uint16:
   422  		return int64(x)
   423  	case int32:
   424  		return int64(x)
   425  	case uint32:
   426  		return int64(x)
   427  	case int64:
   428  		return x
   429  	case uint64:
   430  		return int64(x)
   431  	case float32:
   432  		return int64(x)
   433  	case float64:
   434  		return int64(x)
   435  	default:
   436  		panic(fmt.Sprintf("Expected a numeric auto increment value, but got %T", x))
   437  	}
   438  }