github.com/XiaoMi/Gaea@v1.2.5/proxy/plan/plan_unshard.go (about) 1 // Copyright 2019 The Gaea Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package plan 16 17 import ( 18 "fmt" 19 "strings" 20 21 "github.com/XiaoMi/Gaea/mysql" 22 "github.com/XiaoMi/Gaea/parser/ast" 23 "github.com/XiaoMi/Gaea/parser/format" 24 "github.com/XiaoMi/Gaea/util" 25 ) 26 27 // UnshardPlan is the plan for unshard statement 28 type UnshardPlan struct { 29 basePlan 30 31 db string 32 phyDBs map[string]string 33 sql string 34 stmt ast.StmtNode 35 } 36 37 // SelectLastInsertIDPlan is the plan for SELECT LAST_INSERT_ID() 38 // TODO: fix below 39 // https://dev.mysql.com/doc/refman/5.6/en/information-functions.html#function_last-insert-id 40 // The value of LAST_INSERT_ID() is not changed if you set the AUTO_INCREMENT column of a row 41 // to a non-“magic” value (that is, a value that is not NULL and not 0). 42 type SelectLastInsertIDPlan struct { 43 basePlan 44 } 45 46 // IsSelectLastInsertIDStmt check if the statement is SELECT LAST_INSERT_ID() 47 func IsSelectLastInsertIDStmt(stmt ast.StmtNode) bool { 48 s, ok := stmt.(*ast.SelectStmt) 49 if !ok { 50 return false 51 } 52 53 if len(s.Fields.Fields) != 1 { 54 return false 55 } 56 57 if s.From != nil || s.Where != nil || s.GroupBy != nil || s.Having != nil || s.OrderBy != nil || s.Limit != nil { 58 return false 59 } 60 61 f, ok := s.Fields.Fields[0].Expr.(*ast.FuncCallExpr) 62 if !ok { 63 return false 64 } 65 66 return f.FnName.L == "last_insert_id" 67 } 68 69 // CreateUnshardPlan constructor of UnshardPlan 70 func CreateUnshardPlan(stmt ast.StmtNode, phyDBs map[string]string, db string, tableNames []*ast.TableName) (*UnshardPlan, error) { 71 p := &UnshardPlan{ 72 db: db, 73 phyDBs: phyDBs, 74 stmt: stmt, 75 } 76 rewriteUnshardTableName(phyDBs, tableNames) 77 rsql, err := generateUnshardingSQL(stmt) 78 if err != nil { 79 return nil, fmt.Errorf("generate unshardPlan SQL error: %v", err) 80 } 81 p.sql = rsql 82 return p, nil 83 } 84 85 func rewriteUnshardTableName(phyDBs map[string]string, tableNames []*ast.TableName) { 86 for _, tableName := range tableNames { 87 if phyDB, ok := phyDBs[tableName.Schema.String()]; ok { 88 tableName.Schema.O = phyDB 89 tableName.Schema.L = strings.ToLower(phyDB) 90 } 91 } 92 } 93 94 func generateUnshardingSQL(stmt ast.StmtNode) (string, error) { 95 s := &strings.Builder{} 96 ctx := format.NewRestoreCtx(format.EscapeRestoreFlags, s) 97 _ = stmt.Restore(ctx) 98 return s.String(), nil 99 } 100 101 // CreateSelectLastInsertIDPlan constructor of SelectLastInsertIDPlan 102 func CreateSelectLastInsertIDPlan() *SelectLastInsertIDPlan { 103 return &SelectLastInsertIDPlan{} 104 } 105 106 // ExecuteIn implement Plan 107 func (p *UnshardPlan) ExecuteIn(reqCtx *util.RequestContext, se Executor) (*mysql.Result, error) { 108 r, err := se.ExecuteSQL(reqCtx, reqCtx.Get(util.DefaultSlice).(string), p.db, p.sql) 109 if err != nil { 110 return nil, err 111 } 112 113 // set last insert id to session 114 if _, ok := p.stmt.(*ast.InsertStmt); ok { 115 if r.InsertID != 0 { 116 se.SetLastInsertID(r.InsertID) 117 } 118 } 119 120 return r, nil 121 } 122 123 // ExecuteIn implement Plan 124 func (p *SelectLastInsertIDPlan) ExecuteIn(reqCtx *util.RequestContext, se Executor) (*mysql.Result, error) { 125 r := createLastInsertIDResult(se.GetLastInsertID()) 126 return r, nil 127 } 128 129 func createLastInsertIDResult(lastInsertID uint64) *mysql.Result { 130 name := "last_insert_id()" 131 var column = 1 132 var rows [][]string 133 var names = []string{ 134 name, 135 } 136 137 var t = fmt.Sprintf("%d", lastInsertID) 138 rows = append(rows, []string{t}) 139 140 r := new(mysql.Resultset) 141 142 var values = make([][]interface{}, len(rows)) 143 for i := range rows { 144 values[i] = make([]interface{}, column) 145 for j := range rows[i] { 146 values[i][j] = rows[i][j] 147 } 148 } 149 150 r, _ = mysql.BuildResultset(nil, names, values) 151 ret := &mysql.Result{ 152 Resultset: r, 153 } 154 155 return ret 156 }