github.com/XiaoMi/Gaea@v1.2.5/proxy/plan/plan_insert.go (about) 1 // Copyright 2019 The Gaea Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package plan 16 17 import ( 18 "fmt" 19 20 "github.com/XiaoMi/Gaea/core/errors" 21 "github.com/XiaoMi/Gaea/log" 22 "github.com/XiaoMi/Gaea/mysql" 23 "github.com/XiaoMi/Gaea/parser/ast" 24 driver "github.com/XiaoMi/Gaea/parser/tidb-types/parser_driver" 25 "github.com/XiaoMi/Gaea/proxy/router" 26 "github.com/XiaoMi/Gaea/proxy/sequence" 27 "github.com/XiaoMi/Gaea/util" 28 ) 29 30 // InsertPlan is the plan for insert statement 31 type InsertPlan struct { 32 basePlan 33 *StmtInfo 34 35 stmt *ast.InsertStmt 36 37 table string 38 isAssignmentMode bool 39 shardingColumnIndex int 40 41 sequences *sequence.SequenceManager 42 43 sqls map[string]map[string][]string 44 } 45 46 // NewInsertPlan constructor of InsertPlan 47 func NewInsertPlan(db string, sql string, r *router.Router, seq *sequence.SequenceManager) *InsertPlan { 48 return &InsertPlan{ 49 StmtInfo: NewStmtInfo(db, sql, r), 50 shardingColumnIndex: -1, 51 sequences: seq, 52 } 53 } 54 55 // GetStmt return InsertStmt 56 func (s *InsertPlan) GetStmt() *ast.InsertStmt { 57 return s.stmt 58 } 59 60 // HandleInsertStmt build a InsertPlan 61 func HandleInsertStmt(p *InsertPlan, stmt *ast.InsertStmt) error { 62 p.stmt = stmt 63 64 if err := precheckInsertStmt(p); err != nil { 65 return err 66 } 67 68 // 处理全局表成功时会触发fastReturn 69 fastReturn, err := handleInsertTableRefs(p) 70 if err != nil { 71 return fmt.Errorf("handleInsertTableRefs error: %v", err) 72 } 73 if fastReturn { 74 return nil 75 } 76 77 if err := handleInsertGlobalSequenceValue(p); err != nil { 78 return fmt.Errorf("handleInsertGlobalSequenceValue error: %v", err) 79 } 80 81 if err := handleInsertColumnNames(p); err != nil { 82 return fmt.Errorf("handleInsertColumnNames error: %v", err) 83 } 84 85 if err := handleInsertOnDuplicate(p); err != nil { 86 return fmt.Errorf("handleInsertOnDuplicate error: %v", err) 87 } 88 89 if err := handleInsertValues(p); err != nil { 90 return fmt.Errorf("handleInsertValues error: %v", err) 91 } 92 93 sqls, err := generateShardingSQLs(p.stmt, p.result, p.router) 94 if err != nil { 95 log.Warn("generate insert sql failed, %v", err) 96 return err 97 } 98 99 p.sqls = sqls 100 101 return nil 102 } 103 104 func precheckInsertStmt(p *InsertPlan) error { 105 stmt := p.stmt 106 // doesn't support insert into select... 107 if stmt.Select != nil { 108 return errors.ErrSelectInInsert 109 } 110 111 // INSERT INTO tbl SET col=val, ... 112 if len(stmt.Setlist) != 0 { 113 p.isAssignmentMode = true 114 return nil 115 } 116 117 if len(stmt.Columns) == 0 { 118 return errors.ErrIRNoColumns 119 } 120 121 values := stmt.Lists[0] 122 if len(stmt.Columns) != len(values) { 123 return fmt.Errorf("column count doesn't match value count") 124 } 125 126 return nil 127 } 128 129 func handleInsertTableRefs(p *InsertPlan) (fastReturn bool, err error) { 130 if p.stmt.Table.TableRefs.Right != nil { 131 return false, fmt.Errorf("have multi tables in insert") 132 } 133 tableSource, ok := p.stmt.Table.TableRefs.Left.(*ast.TableSource) 134 if !ok { 135 return false, fmt.Errorf("not a table source") 136 } 137 tableName := tableSource.Source.(*ast.TableName) 138 p.table = tableName.Name.L 139 140 rule, need, err := NeedCreateTableNameDecoratorWithoutAlias(p.StmtInfo, tableName) 141 if err != nil { 142 return false, fmt.Errorf("check table name need to decorate error: %v", err) 143 } 144 145 if !need { 146 // 如果不需要装饰, 不应该走到分表逻辑, 直接报错 147 return false, fmt.Errorf("not a sharding table") 148 } 149 150 decorator, err := CreateTableNameDecorator(tableName, rule, p.GetRouteResult()) 151 if err != nil { 152 return false, fmt.Errorf("create table name decorator error: %v", err) 153 } 154 155 tableSource.Source = decorator 156 157 // 如果是全局表, 则将记录写入所有分片 158 if rule.GetType() == router.GlobalTableRuleType { 159 p.result.db = rule.GetDB() 160 p.result.table = rule.GetTable() 161 p.result.indexes = rule.GetSubTableIndexes() 162 sqls, err := generateShardingSQLs(p.stmt, p.result, p.router) 163 if err != nil { 164 return false, fmt.Errorf("generate global table insert sql error: %v", err) 165 } 166 p.sqls = sqls 167 return true, nil 168 } 169 170 return false, nil 171 } 172 173 func handleInsertColumnNames(p *InsertPlan) error { 174 if p.isAssignmentMode { 175 // INSERT INTO tbl SET col = val, ... 176 for i, assignment := range p.stmt.Setlist { 177 col := assignment.Column 178 removeSchemaAndTableInfoInColumnName(col) 179 columnName := col.Name.L 180 rule := p.tableRules[p.table] 181 if columnName == rule.GetShardingColumn() { 182 p.shardingColumnIndex = i 183 } 184 } 185 } else { 186 // INSERT INTO tbl (col, ...) VALUES (val, ...) 187 for i, col := range p.stmt.Columns { 188 removeSchemaAndTableInfoInColumnName(col) 189 columnName := col.Name.L 190 rule := p.tableRules[p.table] 191 if columnName == rule.GetShardingColumn() { 192 p.shardingColumnIndex = i 193 } 194 } 195 } 196 if p.shardingColumnIndex == -1 { 197 return fmt.Errorf("sharding column not found") 198 } 199 return nil 200 } 201 202 // 只有一个表, 直接去掉DB名和表名, 就不需要加装饰器了 203 func removeSchemaAndTableInfoInColumnName(column *ast.ColumnName) { 204 column.Schema.O = "" 205 column.Schema.L = "" 206 column.Table.O = "" 207 column.Table.L = "" 208 } 209 210 // TODO: refactor 211 func handleInsertValues(p *InsertPlan) error { 212 // assignment mode 213 if p.isAssignmentMode { 214 valueItem := p.stmt.Setlist[p.shardingColumnIndex].Expr 215 switch x := valueItem.(type) { 216 case *driver.ValueExpr: 217 v, err := util.GetValueExprResult(x) 218 if err != nil { 219 return fmt.Errorf("get value expr result failed, %v", err) 220 } 221 if v == nil { 222 return fmt.Errorf("sharding value cannot be null") 223 } 224 routeIdx, err := p.tableRules[p.table].FindTableIndex(v) 225 if err != nil { 226 return fmt.Errorf("find table index error: %v", err) 227 } 228 p.result.Inter([]int{routeIdx}) 229 } 230 return nil 231 } 232 233 // not assignment mode 234 for _, valueList := range p.stmt.Lists { 235 valueItem := valueList[p.shardingColumnIndex] 236 switch x := valueItem.(type) { 237 case *driver.ValueExpr: 238 v, err := util.GetValueExprResult(x) 239 if err != nil { 240 return fmt.Errorf("get value expr result failed, %v", err) 241 } 242 if v == nil { 243 return fmt.Errorf("sharding value cannot be null") 244 } 245 routeIdx, err := p.tableRules[p.table].FindTableIndex(v) 246 if err != nil { 247 return fmt.Errorf("find table index error: %v", err) 248 } 249 p.result.Inter([]int{routeIdx}) 250 } 251 } 252 if len(p.result.GetShardIndexes()) == 0 { 253 return fmt.Errorf("batch insert has cross slice values or no route found") 254 } 255 return nil 256 } 257 258 // check on duplicate key 259 // 不管分片表的配置信息, 只要在OnDuplicate出现分片列, 就返回错误 260 // 去掉ColumnName中的DB名和表名 261 func handleInsertOnDuplicate(p *InsertPlan) error { 262 if p.stmt.OnDuplicate == nil { 263 return nil 264 } 265 266 shardingColumnName := p.tableRules[p.table].GetShardingColumn() 267 for _, a := range p.stmt.OnDuplicate { 268 if a.Column.Name.L == shardingColumnName { 269 return errors.ErrUpdateKey 270 } 271 removeSchemaAndTableInfoInColumnName(a.Column) 272 } 273 274 return nil 275 } 276 277 // 处理全局序列号, 目前一条SQL中只允许一个列使用全局序列号 278 func handleInsertGlobalSequenceValue(p *InsertPlan) error { 279 seq, ok := p.sequences.GetSequence(p.db, p.table) 280 if !ok { 281 return nil 282 } 283 pkName := seq.GetPKName() 284 285 // not assignment mode 286 if p.isAssignmentMode { 287 for _, assignment := range p.stmt.Setlist { 288 columnName := assignment.Column.Name.L 289 if columnName == pkName { 290 if x, ok := assignment.Expr.(*ast.FuncCallExpr); ok { 291 if x.FnName.L == "nextval" { 292 id, err := seq.NextSeq() 293 if err != nil { 294 return fmt.Errorf("get next seq error: %v", err) 295 } 296 assignment.Expr = ast.NewValueExpr(id) 297 break 298 } 299 } 300 } 301 } 302 return nil 303 } 304 305 // not assignment mode 306 var seqIndex = -1 307 for i, column := range p.stmt.Columns { 308 columnName := column.Name.L 309 if columnName == pkName { 310 seqIndex = i 311 break 312 } 313 } 314 315 // global sequence column not found 316 if seqIndex == -1 { 317 return nil 318 } 319 320 for _, valueList := range p.stmt.Lists { 321 if x, ok := valueList[seqIndex].(*ast.FuncCallExpr); ok { 322 if x.FnName.L == "nextval" { 323 id, err := seq.NextSeq() 324 if err != nil { 325 return fmt.Errorf("get next seq error: %v", err) 326 } 327 valueList[seqIndex] = ast.NewValueExpr(id) 328 } 329 } 330 } 331 332 return nil 333 } 334 335 // ExecuteIn implement Plan 336 func (s *InsertPlan) ExecuteIn(reqCtx *util.RequestContext, sess Executor) (*mysql.Result, error) { 337 rs, err := sess.ExecuteSQLs(reqCtx, s.sqls) 338 if err != nil { 339 return nil, fmt.Errorf("execute in InsertPlan error: %v", err) 340 } 341 342 r, err := MergeExecResult(rs) 343 if err != nil { 344 return nil, err 345 } 346 347 if r.InsertID != 0 { 348 sess.SetLastInsertID(r.InsertID) 349 } 350 351 return r, nil 352 }