github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/apply_update_accumulators.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "fmt" 19 20 "github.com/dolthub/go-mysql-server/sql/transform" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/plan" 24 ) 25 26 // applyUpdateAccumulators wraps any Insert, Update, or Delete nodes with RowUpdateAccumulators to tally the results 27 // for report to the client. 28 func applyUpdateAccumulators(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 29 switch n := n.(type) { 30 case *plan.TriggerExecutor, *plan.InsertInto, *plan.DeleteFrom, *plan.Update: 31 accumulatorType, err := getUpdateAccumulatorType(n) 32 if err != nil { 33 return nil, transform.SameTree, err 34 } 35 return plan.NewRowUpdateAccumulator(n, accumulatorType), transform.NewTree, nil 36 default: 37 return n, transform.SameTree, nil 38 } 39 } 40 41 // getUpdateAccumulatorType returns the type of accumulator needed for the node given, or an error if there's no match. 42 func getUpdateAccumulatorType(n sql.Node) (plan.RowUpdateType, error) { 43 switch n := n.(type) { 44 case *plan.TriggerExecutor: 45 return getUpdateAccumulatorType(n.Left()) 46 case *plan.InsertInto: 47 if n.IsReplace { 48 return plan.UpdateTypeReplace, nil 49 } else if len(n.OnDupExprs) > 0 { 50 return plan.UpdateTypeDuplicateKeyUpdate, nil 51 } 52 return plan.UpdateTypeInsert, nil 53 case *plan.DeleteFrom: 54 return plan.UpdateTypeDelete, nil 55 case *plan.Update: 56 // search for a join 57 hasJoin := false 58 transform.Inspect(n, func(node sql.Node) bool { 59 switch node.(type) { 60 case *plan.JoinNode: 61 hasJoin = true 62 return false 63 } 64 65 return true 66 }) 67 68 if hasJoin { 69 return plan.UpdateTypeJoinUpdate, nil 70 } 71 72 return plan.UpdateTypeUpdate, nil 73 } 74 75 return -1, fmt.Errorf("unexpected node type: %T", n) 76 }