github.com/XiaoMi/Gaea@v1.2.5/proxy/plan/plan_update.go (about) 1 // Copyright 2019 The Gaea Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package plan 16 17 import ( 18 "fmt" 19 20 "github.com/XiaoMi/Gaea/mysql" 21 "github.com/XiaoMi/Gaea/parser/ast" 22 "github.com/XiaoMi/Gaea/proxy/router" 23 "github.com/XiaoMi/Gaea/util" 24 ) 25 26 // UpdatePlan is the plan for update statement 27 type UpdatePlan struct { 28 basePlan 29 *TableAliasStmtInfo 30 31 stmt *ast.UpdateStmt 32 sqls map[string]map[string][]string 33 } 34 35 // NewUpdatePlan constructor of UpdatePlan 36 func NewUpdatePlan(stmt *ast.UpdateStmt, db, sql string, r *router.Router) *UpdatePlan { 37 return &UpdatePlan{ 38 TableAliasStmtInfo: NewTableAliasStmtInfo(db, sql, r), 39 stmt: stmt, 40 } 41 } 42 43 // ExecuteIn implement Plan 44 func (s *UpdatePlan) ExecuteIn(reqCtx *util.RequestContext, sess Executor) (*mysql.Result, error) { 45 sqls := s.sqls 46 if sqls == nil { 47 return nil, fmt.Errorf("SQL has not generated") 48 } 49 50 if len(sqls) == 0 { 51 return nil, nil 52 } 53 54 rs, err := sess.ExecuteSQLs(reqCtx, sqls) 55 if err != nil { 56 return nil, fmt.Errorf("execute in UpdatePlan error: %v", err) 57 } 58 59 r, err := MergeExecResult(rs) 60 61 if err != nil { 62 return nil, fmt.Errorf("merge update result error: %v", err) 63 } 64 65 return r, nil 66 } 67 68 // HandleUpdatePlan build a UpdatePlan 69 func HandleUpdatePlan(p *UpdatePlan) error { 70 if err := handleUpdateTableRefs(p); err != nil { 71 return fmt.Errorf("handle From error: %v", err) 72 } 73 74 if err := handleUpdateAssignmentList(p); err != nil { 75 return fmt.Errorf("handle assignment list error: %v", err) 76 } 77 78 if err := handleUpdateWhere(p); err != nil { 79 return fmt.Errorf("handle Where error: %v", err) 80 } 81 82 if err := handleUpdateOrderBy(p); err != nil { 83 return fmt.Errorf("handle OrderBy error: %v", err) 84 } 85 86 // Limit clause does not need to handle 87 88 // handle global table 89 if err := postHandleGlobalTableRouteResultInModify(p.StmtInfo); err != nil { 90 return fmt.Errorf("post handle global table error: %v", err) 91 } 92 93 sqls, err := generateShardingSQLs(p.stmt, p.GetRouteResult(), p.router) 94 if err != nil { 95 return fmt.Errorf("generate sqls error: %v", err) 96 } 97 98 p.sqls = sqls 99 return nil 100 } 101 102 func handleUpdateTableRefs(p *UpdatePlan) error { 103 tableRefs := p.stmt.TableRefs 104 if tableRefs == nil { 105 return nil 106 } 107 108 join := tableRefs.TableRefs 109 if join == nil { 110 return nil 111 } 112 113 if join.Right != nil { 114 return fmt.Errorf("does not support update multiple tables in sharding") 115 } 116 117 return handleJoin(p.TableAliasStmtInfo, join) 118 } 119 120 func handleUpdateWhere(p *UpdatePlan) error { 121 stmt := p.stmt 122 if stmt.Where == nil { 123 return nil 124 } 125 126 has, result, decorator, err := handleComparisonExpr(p.TableAliasStmtInfo, stmt.Where) 127 if err != nil { 128 return fmt.Errorf("rewrite Where error: %v", err) 129 } 130 if has { 131 p.GetRouteResult().Inter(result) 132 } 133 stmt.Where = decorator 134 return nil 135 } 136 137 func handleUpdateOrderBy(p *UpdatePlan) error { 138 order := p.stmt.Order 139 if order == nil { 140 return nil 141 } 142 143 for _, item := range order.Items { 144 columnExpr, ok := item.Expr.(*ast.ColumnNameExpr) 145 if !ok { 146 return fmt.Errorf("ByItem.Expr is not a ColumnNameExpr") 147 } 148 149 rule, need, isAlias, err := NeedCreateColumnNameExprDecoratorInField(p.TableAliasStmtInfo, columnExpr) 150 if err != nil { 151 return err 152 } 153 154 if need { 155 decorator := CreateColumnNameExprDecorator(columnExpr, rule, isAlias, p.GetRouteResult()) 156 item.Expr = decorator 157 } 158 } 159 160 return nil 161 } 162 163 // TODO: Assignment直接引用ColumnName, 不能做表名的装饰器. 采用的解决办法是UPDATE只支持一个表, 然后把DB名和表名去掉. 164 func handleUpdateAssignmentList(p *UpdatePlan) error { 165 l := p.stmt.List 166 for _, assignment := range l { 167 r, need, _, err := needCreateColumnNameDecorator(p.TableAliasStmtInfo, assignment.Column) 168 if err != nil { 169 return err 170 } 171 172 if need && r.GetShardingColumn() == assignment.Column.Name.L { 173 return fmt.Errorf("cannot update shard column value") 174 } 175 removeSchemaAndTableInfoInColumnName(assignment.Column) 176 } 177 return nil 178 }