github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/update.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rowexec 16 17 import ( 18 "errors" 19 "fmt" 20 21 "github.com/dolthub/go-mysql-server/sql" 22 "github.com/dolthub/go-mysql-server/sql/plan" 23 ) 24 25 type updateIter struct { 26 childIter sql.RowIter 27 schema sql.Schema 28 updater sql.RowUpdater 29 checks sql.CheckConstraints 30 closed bool 31 ignore bool 32 } 33 34 func (u *updateIter) Next(ctx *sql.Context) (sql.Row, error) { 35 oldAndNewRow, err := u.childIter.Next(ctx) 36 if err != nil { 37 return nil, err 38 } 39 40 oldRow, newRow := oldAndNewRow[:len(oldAndNewRow)/2], oldAndNewRow[len(oldAndNewRow)/2:] 41 if equals, err := oldRow.Equals(newRow, u.schema); err == nil { 42 if !equals { 43 // apply check constraints 44 for _, check := range u.checks { 45 if !check.Enforced { 46 continue 47 } 48 49 res, err := sql.EvaluateCondition(ctx, check.Expr, newRow) 50 if err != nil { 51 return nil, err 52 } 53 54 if sql.IsFalse(res) { 55 return nil, u.ignoreOrError(ctx, newRow, sql.ErrCheckConstraintViolated.New(check.Name)) 56 } 57 } 58 59 err := u.validateNullability(ctx, newRow, u.schema) 60 if err != nil { 61 return nil, u.ignoreOrError(ctx, newRow, err) 62 } 63 64 err = u.updater.Update(ctx, oldRow, newRow) 65 if err != nil { 66 return nil, u.ignoreOrError(ctx, newRow, err) 67 } 68 } 69 } else { 70 return nil, err 71 } 72 73 return oldAndNewRow, nil 74 } 75 76 // Applies the update expressions given to the row given, returning the new resultant row. In the case that ignore is 77 // provided and there is a type conversion error, this function sets the value to the zero value as per the MySQL standard. 78 func applyUpdateExpressionsWithIgnore(ctx *sql.Context, updateExprs []sql.Expression, tableSchema sql.Schema, row sql.Row, ignore bool) (sql.Row, error) { 79 var secondPass []int 80 81 for i, updateExpr := range updateExprs { 82 defaultVal, isDefaultVal := defaultValFromSetExpression(updateExpr) 83 // Any generated columns must be projected into place so that the caller gets their newest values as well. We 84 // do this in a second pass as necessary. 85 if isDefaultVal && !defaultVal.IsLiteral() { 86 secondPass = append(secondPass, i) 87 continue 88 } 89 90 val, err := updateExpr.Eval(ctx, row) 91 if err != nil { 92 var wtce sql.WrappedTypeConversionError 93 isTypeConversionError := errors.As(err, &wtce) 94 if !isTypeConversionError || !ignore { 95 return nil, err 96 } 97 98 cpy := row.Copy() 99 cpy[wtce.OffendingIdx] = wtce.OffendingVal // Needed for strings 100 val = convertDataAndWarn(ctx, tableSchema, cpy, wtce.OffendingIdx, wtce.Err) 101 } 102 var ok bool 103 row, ok = val.(sql.Row) 104 if !ok { 105 return nil, plan.ErrUpdateUnexpectedSetResult.New(val) 106 } 107 } 108 109 for _, index := range secondPass { 110 val, err := updateExprs[index].Eval(ctx, row) 111 if err != nil { 112 return nil, err 113 } 114 115 var ok bool 116 row, ok = val.(sql.Row) 117 if !ok { 118 return nil, plan.ErrUpdateUnexpectedSetResult.New(val) 119 } 120 } 121 122 return row, nil 123 } 124 125 func (u *updateIter) validateNullability(ctx *sql.Context, row sql.Row, schema sql.Schema) error { 126 for idx := 0; idx < len(row); idx++ { 127 col := schema[idx] 128 if !col.Nullable && row[idx] == nil { 129 // In the case of an IGNORE we set the nil value to a default and add a warning 130 if u.ignore { 131 row[idx] = col.Type.Zero() 132 _ = warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil 133 } else { 134 return sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name) 135 } 136 137 } 138 } 139 return nil 140 } 141 142 func (u *updateIter) Close(ctx *sql.Context) error { 143 if !u.closed { 144 u.closed = true 145 if err := u.updater.Close(ctx); err != nil { 146 return err 147 } 148 return u.childIter.Close(ctx) 149 } 150 return nil 151 } 152 153 func (u *updateIter) ignoreOrError(ctx *sql.Context, row sql.Row, err error) error { 154 if !u.ignore { 155 return err 156 } 157 158 return warnOnIgnorableError(ctx, row, err) 159 } 160 161 func newUpdateIter( 162 childIter sql.RowIter, 163 schema sql.Schema, 164 updater sql.RowUpdater, 165 checks sql.CheckConstraints, 166 ignore bool, 167 ) sql.RowIter { 168 if ignore { 169 return plan.NewCheckpointingTableEditorIter(&updateIter{ 170 childIter: childIter, 171 updater: updater, 172 schema: schema, 173 checks: checks, 174 ignore: true, 175 }, updater) 176 } else { 177 return plan.NewTableEditorIter(&updateIter{ 178 childIter: childIter, 179 updater: updater, 180 schema: schema, 181 checks: checks, 182 }, updater) 183 } 184 } 185 186 // updateJoinIter wraps the child UpdateSource projectIter and returns join row in such a way that updates per table row are 187 // done once. 188 type updateJoinIter struct { 189 updateSourceIter sql.RowIter 190 joinSchema sql.Schema 191 updaters map[string]sql.RowUpdater 192 caches map[string]sql.KeyValueCache 193 disposals map[string]sql.DisposeFunc 194 joinNode sql.Node 195 } 196 197 var _ sql.RowIter = (*updateJoinIter)(nil) 198 199 func (u *updateJoinIter) Next(ctx *sql.Context) (sql.Row, error) { 200 for { 201 oldAndNewRow, err := u.updateSourceIter.Next(ctx) 202 if err != nil { 203 return nil, err 204 } 205 206 oldJoinRow, newJoinRow := oldAndNewRow[:len(oldAndNewRow)/2], oldAndNewRow[len(oldAndNewRow)/2:] 207 208 tableToOldRowMap := plan.SplitRowIntoTableRowMap(oldJoinRow, u.joinSchema) 209 tableToNewRowMap := plan.SplitRowIntoTableRowMap(newJoinRow, u.joinSchema) 210 211 for tableName, _ := range u.updaters { 212 oldTableRow := tableToOldRowMap[tableName] 213 214 // Handle the case of row being ignored due to it not being valid in the join row. 215 if isRightOrLeftJoin(u.joinNode) { 216 works, err := u.shouldUpdateDirectionalJoin(ctx, oldJoinRow, oldTableRow) 217 if err != nil { 218 return nil, err 219 } 220 221 if !works { 222 // rewrite the newJoinRow to ensure an update does not happen 223 tableToNewRowMap[tableName] = oldTableRow 224 continue 225 } 226 } 227 228 // Determine whether this row in the table has already been updated 229 cache := u.getOrCreateCache(ctx, tableName) 230 hash, err := sql.HashOf(oldTableRow) 231 if err != nil { 232 return nil, err 233 } 234 235 _, err = cache.Get(hash) 236 if errors.Is(err, sql.ErrKeyNotFound) { 237 cache.Put(hash, struct{}{}) 238 continue 239 } else if err != nil { 240 return nil, err 241 } 242 243 // If this row for the table has already been updated we rewrite the newJoinRow counterpart to ensure that this 244 // returned row is not incorrectly counted by the update accumulator. 245 tableToNewRowMap[tableName] = oldTableRow 246 } 247 248 newJoinRow = recreateRowFromMap(tableToNewRowMap, u.joinSchema) 249 equals, err := oldJoinRow.Equals(newJoinRow, u.joinSchema) 250 if err != nil { 251 return nil, err 252 } 253 if !equals { 254 return append(oldJoinRow, newJoinRow...), nil 255 } 256 } 257 } 258 259 func toJoinNode(node sql.Node) *plan.JoinNode { 260 switch n := node.(type) { 261 case *plan.JoinNode: 262 return n 263 case *plan.TopN: 264 return toJoinNode(n.Child) 265 case *plan.Filter: 266 return toJoinNode(n.Child) 267 case *plan.Project: 268 return toJoinNode(n.Child) 269 case *plan.Limit: 270 return toJoinNode(n.Child) 271 case *plan.Offset: 272 return toJoinNode(n.Child) 273 case *plan.Sort: 274 return toJoinNode(n.Child) 275 case *plan.Distinct: 276 return toJoinNode(n.Child) 277 case *plan.Having: 278 return toJoinNode(n.Child) 279 case *plan.Window: 280 return toJoinNode(n.Child) 281 default: 282 return nil 283 } 284 } 285 286 func isIndexedAccess(node sql.Node) bool { 287 switch n := node.(type) { 288 case *plan.Filter: 289 return isIndexedAccess(n.Child) 290 case *plan.TableAlias: 291 return isIndexedAccess(n.Child) 292 case *plan.JoinNode: 293 return isIndexedAccess(n.Left()) 294 case *plan.IndexedTableAccess: 295 return true 296 } 297 return false 298 } 299 300 func isRightOrLeftJoin(node sql.Node) bool { 301 jn := toJoinNode(node) 302 if jn == nil { 303 return false 304 } 305 return jn.JoinType().IsLeftOuter() 306 } 307 308 // shouldUpdateDirectionalJoin determines whether a table row should be updated in the context of a large right/left join row. 309 // A table row should only be updated if 1) It fits the join conditions (the intersection of the join) 2) It fits only 310 // the left or right side of the join (given the direction). A row of all nils that does not pass condition 1 must not 311 // be part of the update operation. This is follows the logic as established in the joinIter. 312 func (u *updateJoinIter) shouldUpdateDirectionalJoin(ctx *sql.Context, joinRow, tableRow sql.Row) (bool, error) { 313 jn := toJoinNode(u.joinNode) 314 if jn == nil || !jn.JoinType().IsLeftOuter() { 315 return true, fmt.Errorf("expected left join") 316 } 317 318 // If the overall row fits the join condition it is fine (i.e. middle of the venn diagram). 319 val, err := jn.JoinCond().Eval(ctx, joinRow) 320 if err != nil { 321 return true, err 322 } 323 if v, ok := val.(bool); ok && v && !isIndexedAccess(jn) { 324 return true, nil 325 } 326 327 for _, v := range tableRow { 328 if v != nil { 329 return true, nil 330 } 331 } 332 333 // If the row is all nils we know it should not be updated as per the function description. 334 return false, nil 335 } 336 337 func (u *updateJoinIter) Close(context *sql.Context) error { 338 for _, disposeF := range u.disposals { 339 disposeF() 340 } 341 342 return u.updateSourceIter.Close(context) 343 } 344 345 func (u *updateJoinIter) getOrCreateCache(ctx *sql.Context, tableName string) sql.KeyValueCache { 346 potential, exists := u.caches[tableName] 347 if exists { 348 return potential 349 } 350 351 cache, disposal := ctx.Memory.NewHistoryCache() 352 u.caches[tableName] = cache 353 u.disposals[tableName] = disposal 354 355 return cache 356 } 357 358 // recreateRowFromMap takes a join schema and row map and recreates the original join row. 359 func recreateRowFromMap(rowMap map[string]sql.Row, joinSchema sql.Schema) sql.Row { 360 var ret sql.Row 361 362 if len(joinSchema) == 0 { 363 return ret 364 } 365 366 currentTable := joinSchema[0].Source 367 ret = append(ret, rowMap[currentTable]...) 368 369 for i := 1; i < len(joinSchema); i++ { 370 c := joinSchema[i] 371 372 if c.Source != currentTable { 373 ret = append(ret, rowMap[c.Source]...) 374 currentTable = c.Source 375 } 376 } 377 378 return ret 379 }