github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/insert.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rowexec 16 17 import ( 18 "fmt" 19 "io" 20 21 "github.com/dolthub/vitess/go/vt/proto/query" 22 "gopkg.in/src-d/go-errors.v1" 23 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/expression" 26 "github.com/dolthub/go-mysql-server/sql/expression/function" 27 "github.com/dolthub/go-mysql-server/sql/plan" 28 "github.com/dolthub/go-mysql-server/sql/transform" 29 "github.com/dolthub/go-mysql-server/sql/types" 30 ) 31 32 type insertIter struct { 33 schema sql.Schema 34 inserter sql.RowInserter 35 replacer sql.RowReplacer 36 updater sql.RowUpdater 37 rowSource sql.RowIter 38 lastInsertIdUpdated bool 39 hasAutoAutoIncValue bool 40 ctx *sql.Context 41 insertExprs []sql.Expression 42 updateExprs []sql.Expression 43 checks sql.CheckConstraints 44 tableNode sql.Node 45 closed bool 46 ignore bool 47 } 48 49 func getInsertExpressions(values sql.Node) []sql.Expression { 50 var exprs []sql.Expression 51 transform.Inspect(values, func(node sql.Node) bool { 52 switch node := node.(type) { 53 case *plan.Project: 54 exprs = node.Projections 55 return false 56 } 57 return true 58 }) 59 return exprs 60 } 61 62 func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) { 63 row, err := i.rowSource.Next(ctx) 64 if err == io.EOF { 65 return nil, err 66 } 67 68 if err != nil { 69 return nil, i.ignoreOrClose(ctx, row, err) 70 } 71 72 // Prune the row down to the size of the schema. It can be larger in the case of running with an outer scope, in which 73 // case the additional scope variables are prepended to the row. 74 if len(row) > len(i.schema) { 75 row = row[len(row)-len(i.schema):] 76 } 77 78 err = i.validateNullability(ctx, i.schema, row) 79 if err != nil { 80 return nil, i.ignoreOrClose(ctx, row, err) 81 } 82 83 err = i.evaluateChecks(ctx, row) 84 if err != nil { 85 return nil, i.ignoreOrClose(ctx, row, err) 86 } 87 88 origRow := make(sql.Row, len(row)) 89 copy(origRow, row) 90 91 // Do any necessary type conversions to the target schema 92 for idx, col := range i.schema { 93 if row[idx] != nil { 94 converted, inRange, cErr := col.Type.Convert(row[idx]) 95 if cErr == nil && !inRange { 96 cErr = sql.ErrValueOutOfRange.New(row[idx], col.Type) 97 } 98 if cErr != nil { 99 // Ignore individual column errors when INSERT IGNORE, UPDATE IGNORE, etc. is specified. 100 // For JSON column types, always throw an error. MySQL throws the following error even when 101 // IGNORE is specified: 102 // ERROR 3140 (22032): Invalid JSON text: "Invalid value." at position 0 in value for column 103 // 'table.column'. 104 if i.ignore && col.Type.Type() != query.Type_JSON { 105 if _, ok := col.Type.(sql.NumberType); ok { 106 if converted == nil { 107 converted = i.schema[idx].Type.Zero() 108 } 109 row[idx] = converted 110 // Add a warning instead 111 ctx.Session.Warn(&sql.Warning{ 112 Level: "Note", 113 Code: sql.CastSQLError(cErr).Num, 114 Message: cErr.Error(), 115 }) 116 } else { 117 row = convertDataAndWarn(ctx, i.schema, row, idx, cErr) 118 } 119 continue 120 } else { 121 // Fill in error with information 122 if types.ErrLengthBeyondLimit.Is(cErr) { 123 cErr = types.ErrLengthBeyondLimit.New(row[idx], col.Name) 124 } else if sql.ErrNotMatchingSRID.Is(cErr) { 125 cErr = sql.ErrNotMatchingSRIDWithColName.New(col.Name, cErr) 126 } 127 return nil, sql.NewWrappedInsertError(origRow, cErr) 128 } 129 } 130 row[idx] = converted 131 } 132 } 133 134 if i.replacer != nil { 135 toReturn := make(sql.Row, len(row)*2) 136 for i := 0; i < len(row); i++ { 137 toReturn[i+len(row)] = row[i] 138 } 139 // May have multiple duplicate pk & unique errors due to multiple indexes 140 //TODO: how does this interact with triggers? 141 for { 142 if err := i.replacer.Insert(ctx, row); err != nil { 143 if !sql.ErrPrimaryKeyViolation.Is(err) && !sql.ErrUniqueKeyViolation.Is(err) { 144 i.rowSource.Close(ctx) 145 i.rowSource = nil 146 return nil, sql.NewWrappedInsertError(row, err) 147 } 148 149 ue := err.(*errors.Error).Cause().(sql.UniqueKeyError) 150 if err = i.replacer.Delete(ctx, ue.Existing); err != nil { 151 i.rowSource.Close(ctx) 152 i.rowSource = nil 153 return nil, sql.NewWrappedInsertError(row, err) 154 } 155 // the row had to be deleted, write the values into the toReturn row 156 copy(toReturn, ue.Existing) 157 } else { 158 break 159 } 160 } 161 return toReturn, nil 162 } else { 163 if err := i.inserter.Insert(ctx, row); err != nil { 164 if (!sql.ErrPrimaryKeyViolation.Is(err) && !sql.ErrUniqueKeyViolation.Is(err) && !sql.ErrDuplicateEntry.Is(err)) || len(i.updateExprs) == 0 { 165 return nil, i.ignoreOrClose(ctx, row, err) 166 } 167 168 ue := err.(*errors.Error).Cause().(sql.UniqueKeyError) 169 return i.handleOnDuplicateKeyUpdate(ctx, ue.Existing, row) 170 } 171 } 172 173 i.updateLastInsertId(ctx, row) 174 175 return row, nil 176 } 177 178 func (i *insertIter) handleOnDuplicateKeyUpdate(ctx *sql.Context, oldRow, newRow sql.Row) (returnRow sql.Row, returnErr error) { 179 var err error 180 updateAcc := append(oldRow, newRow...) 181 var evalRow sql.Row 182 for _, updateExpr := range i.updateExprs { 183 // this SET <val> indexes into LHS, but the <expr> can 184 // reference the new row on RHS 185 val, err := updateExpr.Eval(i.ctx, updateAcc) 186 if err != nil { 187 if i.ignore { 188 idx, ok := getFieldIndexFromUpdateExpr(updateExpr) 189 if !ok { 190 return nil, err 191 } 192 193 val = convertDataAndWarn(ctx, i.schema, newRow, idx, err) 194 } else { 195 return nil, err 196 } 197 } 198 199 updateAcc = val.(sql.Row) 200 } 201 // project LHS only 202 evalRow = updateAcc[:len(oldRow)] 203 204 // Should revaluate the check conditions. 205 err = i.evaluateChecks(ctx, evalRow) 206 if err != nil { 207 return nil, i.ignoreOrClose(ctx, newRow, err) 208 } 209 210 err = i.updater.Update(ctx, oldRow, evalRow) 211 if err != nil { 212 return nil, i.ignoreOrClose(ctx, newRow, err) 213 } 214 215 // In the case that we attempted an update, return a concatenated [old,new] row just like update. 216 return oldRow.Append(evalRow), nil 217 } 218 219 func getFieldIndexFromUpdateExpr(updateExpr sql.Expression) (int, bool) { 220 setField, ok := updateExpr.(*expression.SetField) 221 if !ok { 222 return 0, false 223 } 224 225 getField, ok := setField.LeftChild.(*expression.GetField) 226 if !ok { 227 return 0, false 228 } 229 230 return getField.Index(), true 231 } 232 233 // resolveValues resolves all VALUES functions. 234 func (i *insertIter) resolveValues(ctx *sql.Context, insertRow sql.Row) error { 235 for _, updateExpr := range i.updateExprs { 236 var err error 237 sql.Inspect(updateExpr, func(expr sql.Expression) bool { 238 valuesExpr, ok := expr.(*function.Values) 239 if !ok { 240 return true 241 } 242 getField, ok := valuesExpr.Child.(*expression.GetField) 243 if !ok { 244 err = fmt.Errorf("VALUES functions may only contain column names") 245 return false 246 } 247 valuesExpr.Value = insertRow[getField.Index()] 248 return false 249 }) 250 if err != nil { 251 return err 252 } 253 } 254 return nil 255 } 256 257 func (i *insertIter) Close(ctx *sql.Context) error { 258 if !i.closed { 259 i.closed = true 260 var rsErr, iErr, rErr, uErr error 261 if i.rowSource != nil { 262 rsErr = i.rowSource.Close(ctx) 263 } 264 if i.inserter != nil { 265 iErr = i.inserter.Close(ctx) 266 } 267 if i.replacer != nil { 268 rErr = i.replacer.Close(ctx) 269 } 270 if i.updater != nil { 271 uErr = i.updater.Close(ctx) 272 } 273 if rsErr != nil { 274 return rsErr 275 } 276 if iErr != nil { 277 return iErr 278 } 279 if rErr != nil { 280 return rErr 281 } 282 if uErr != nil { 283 return uErr 284 } 285 } 286 return nil 287 } 288 289 func (i *insertIter) updateLastInsertId(ctx *sql.Context, row sql.Row) { 290 if i.lastInsertIdUpdated { 291 return 292 } 293 294 autoIncVal := i.getAutoIncVal(row) 295 296 if i.hasAutoAutoIncValue { 297 ctx.SetLastQueryInfo(sql.LastInsertId, autoIncVal) 298 i.lastInsertIdUpdated = true 299 } 300 } 301 302 func (i *insertIter) getAutoIncVal(row sql.Row) int64 { 303 var autoIncVal int64 304 for i, expr := range i.insertExprs { 305 if _, ok := expr.(*expression.AutoIncrement); ok { 306 autoIncVal = toInt64(row[i]) 307 break 308 } 309 } 310 return autoIncVal 311 } 312 313 func (i *insertIter) ignoreOrClose(ctx *sql.Context, row sql.Row, err error) error { 314 if !i.ignore { 315 return sql.NewWrappedInsertError(row, err) 316 } 317 318 return warnOnIgnorableError(ctx, row, err) 319 } 320 321 // convertDataAndWarn modifies a row with data conversion issues in INSERT/UPDATE IGNORE calls 322 // Per MySQL docs "Rows set to values that would cause data conversion errors are set to the closest valid values instead" 323 // cc. https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sql-mode-strict 324 func convertDataAndWarn(ctx *sql.Context, tableSchema sql.Schema, row sql.Row, columnIdx int, err error) sql.Row { 325 if types.ErrLengthBeyondLimit.Is(err) { 326 maxLength := tableSchema[columnIdx].Type.(sql.StringType).MaxCharacterLength() 327 row[columnIdx] = row[columnIdx].(string)[:maxLength] // truncate string 328 } else { 329 row[columnIdx] = tableSchema[columnIdx].Type.Zero() 330 } 331 332 sqlerr := sql.CastSQLError(err) 333 334 // Add a warning instead 335 if ctx != nil && ctx.Session != nil { 336 ctx.Session.Warn(&sql.Warning{ 337 Level: "Note", 338 Code: sqlerr.Num, 339 Message: err.Error(), 340 }) 341 } 342 343 return row 344 } 345 346 func warnOnIgnorableError(ctx *sql.Context, row sql.Row, err error) error { 347 // Check that this error is a part of the list of Ignorable Errors and create the relevant warning 348 for _, ie := range plan.IgnorableErrors { 349 if ie.Is(err) { 350 sqlerr := sql.CastSQLError(err) 351 352 // Add a warning instead 353 if ctx != nil && ctx.Session != nil { 354 ctx.Session.Warn(&sql.Warning{ 355 Level: "Note", 356 Code: sqlerr.Num, 357 Message: err.Error(), 358 }) 359 } 360 361 // In this case the default value gets updated so return nil 362 if sql.ErrInsertIntoNonNullableDefaultNullColumn.Is(err) { 363 return nil 364 } 365 366 // Return the InsertIgnore err to ensure our accumulator doesn't count this row. 367 return sql.NewIgnorableError(row) 368 } 369 } 370 371 return err 372 } 373 374 func (i *insertIter) evaluateChecks(ctx *sql.Context, row sql.Row) error { 375 for _, check := range i.checks { 376 if !check.Enforced { 377 continue 378 } 379 380 res, err := sql.EvaluateCondition(ctx, check.Expr, row) 381 382 if err != nil { 383 return err 384 } 385 386 if sql.IsFalse(res) { 387 return sql.ErrCheckConstraintViolated.New(check.Name) 388 } 389 } 390 391 return nil 392 } 393 394 func (i *insertIter) validateNullability(ctx *sql.Context, dstSchema sql.Schema, row sql.Row) error { 395 for count, col := range dstSchema { 396 if !col.Nullable && row[count] == nil { 397 // In the case of an IGNORE we set the nil value to a default and add a warning 398 if i.ignore { 399 row[count] = col.Type.Zero() 400 _ = warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil 401 } else { 402 return sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name) 403 } 404 } 405 } 406 return nil 407 } 408 409 func toInt64(x interface{}) int64 { 410 switch x := x.(type) { 411 case int: 412 return int64(x) 413 case uint: 414 return int64(x) 415 case int8: 416 return int64(x) 417 case uint8: 418 return int64(x) 419 case int16: 420 return int64(x) 421 case uint16: 422 return int64(x) 423 case int32: 424 return int64(x) 425 case uint32: 426 return int64(x) 427 case int64: 428 return x 429 case uint64: 430 return int64(x) 431 case float32: 432 return int64(x) 433 case float64: 434 return int64(x) 435 default: 436 panic(fmt.Sprintf("Expected a numeric auto increment value, but got %T", x)) 437 } 438 }