github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/assign_update_join.go (about) 1 package analyzer 2 3 import ( 4 "github.com/dolthub/go-mysql-server/sql" 5 "github.com/dolthub/go-mysql-server/sql/expression" 6 "github.com/dolthub/go-mysql-server/sql/plan" 7 "github.com/dolthub/go-mysql-server/sql/transform" 8 ) 9 10 // modifyUpdateExpressionsForJoin searches for a JOIN for UPDATE query and updates the child of the original update 11 // node to use a plan.UpdateJoin node as a child. 12 func modifyUpdateExpressionsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 13 switch n := n.(type) { 14 case *plan.Update: 15 us, ok := n.Child.(*plan.UpdateSource) 16 if !ok { 17 return n, transform.SameTree, nil 18 } 19 20 var jn sql.Node 21 transform.Inspect(us, func(node sql.Node) bool { 22 switch node.(type) { 23 case *plan.JoinNode: 24 jn = node 25 return false 26 default: 27 return true 28 } 29 }) 30 31 if jn == nil { 32 return n, transform.SameTree, nil 33 } 34 35 updaters, err := rowUpdatersByTable(ctx, us, jn) 36 if err != nil { 37 return nil, transform.SameTree, err 38 } 39 40 uj := plan.NewUpdateJoin(updaters, us) 41 ret, err := n.WithChildren(uj) 42 if err != nil { 43 return nil, transform.SameTree, err 44 } 45 46 return ret, transform.NewTree, nil 47 } 48 49 return n, transform.SameTree, nil 50 } 51 52 // rowUpdatersByTable maps a set of tables to their RowUpdater objects. 53 func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) { 54 namesOfTableToBeUpdated := getTablesToBeUpdated(node) 55 resolvedTables := getTablesByName(ij) 56 57 rowUpdatersByTable := make(map[string]sql.RowUpdater) 58 for tableToBeUpdated, _ := range namesOfTableToBeUpdated { 59 resolvedTable, ok := resolvedTables[tableToBeUpdated] 60 if !ok { 61 return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated) 62 } 63 64 var table = resolvedTable.UnderlyingTable() 65 66 // If there is no UpdatableTable for a table being updated, error out 67 updatable, ok := table.(sql.UpdatableTable) 68 if !ok && updatable == nil { 69 return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated) 70 } 71 72 keyless := sql.IsKeyless(updatable.Schema()) 73 if keyless { 74 return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN") 75 } 76 77 rowUpdatersByTable[tableToBeUpdated] = updatable.Updater(ctx) 78 } 79 80 return rowUpdatersByTable, nil 81 } 82 83 // getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField. 84 func getTablesToBeUpdated(node sql.Node) map[string]struct{} { 85 ret := make(map[string]struct{}) 86 87 transform.InspectExpressions(node, func(e sql.Expression) bool { 88 switch e := e.(type) { 89 case *expression.SetField: 90 gf := e.LeftChild.(*expression.GetField) 91 ret[gf.Table()] = struct{}{} 92 return false 93 } 94 95 return true 96 }) 97 98 return ret 99 }