github.com/XiaoMi/Gaea@v1.2.5/proxy/plan/plan_unshard.go (about)

     1  // Copyright 2019 The Gaea Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/XiaoMi/Gaea/mysql"
    22  	"github.com/XiaoMi/Gaea/parser/ast"
    23  	"github.com/XiaoMi/Gaea/parser/format"
    24  	"github.com/XiaoMi/Gaea/util"
    25  )
    26  
// UnshardPlan is the plan for unshard statement.
// The statement is restored to SQL once, at construction time, and ExecuteIn
// sends that SQL to the slice stored under util.DefaultSlice.
type UnshardPlan struct {
	basePlan

	// db is the database name passed to Executor.ExecuteSQL.
	db     string
	// phyDBs maps a schema name to the physical database name it is
	// rewritten to when the plan is created.
	phyDBs map[string]string
	// sql is the restored SQL text, with schema names already rewritten.
	sql    string
	// stmt is the parsed statement; ExecuteIn checks it for *ast.InsertStmt.
	stmt   ast.StmtNode
}
    36  
// SelectLastInsertIDPlan is the plan for SELECT LAST_INSERT_ID().
// It answers the query from the executor's stored last insert ID
// (Executor.GetLastInsertID) rather than executing SQL.
//
// TODO: the stored value is taken from the InsertID of INSERT results, which
// may diverge from MySQL's LAST_INSERT_ID() semantics. Per
// https://dev.mysql.com/doc/refman/5.6/en/information-functions.html#function_last-insert-id
// The value of LAST_INSERT_ID() is not changed if you set the AUTO_INCREMENT column of a row
// to a non-“magic” value (that is, a value that is not NULL and not 0).
type SelectLastInsertIDPlan struct {
	basePlan
}
    45  
    46  // IsSelectLastInsertIDStmt check if the statement is SELECT LAST_INSERT_ID()
    47  func IsSelectLastInsertIDStmt(stmt ast.StmtNode) bool {
    48  	s, ok := stmt.(*ast.SelectStmt)
    49  	if !ok {
    50  		return false
    51  	}
    52  
    53  	if len(s.Fields.Fields) != 1 {
    54  		return false
    55  	}
    56  
    57  	if s.From != nil || s.Where != nil || s.GroupBy != nil || s.Having != nil || s.OrderBy != nil || s.Limit != nil {
    58  		return false
    59  	}
    60  
    61  	f, ok := s.Fields.Fields[0].Expr.(*ast.FuncCallExpr)
    62  	if !ok {
    63  		return false
    64  	}
    65  
    66  	return f.FnName.L == "last_insert_id"
    67  }
    68  
    69  // CreateUnshardPlan constructor of UnshardPlan
    70  func CreateUnshardPlan(stmt ast.StmtNode, phyDBs map[string]string, db string, tableNames []*ast.TableName) (*UnshardPlan, error) {
    71  	p := &UnshardPlan{
    72  		db:     db,
    73  		phyDBs: phyDBs,
    74  		stmt:   stmt,
    75  	}
    76  	rewriteUnshardTableName(phyDBs, tableNames)
    77  	rsql, err := generateUnshardingSQL(stmt)
    78  	if err != nil {
    79  		return nil, fmt.Errorf("generate unshardPlan SQL error: %v", err)
    80  	}
    81  	p.sql = rsql
    82  	return p, nil
    83  }
    84  
    85  func rewriteUnshardTableName(phyDBs map[string]string, tableNames []*ast.TableName) {
    86  	for _, tableName := range tableNames {
    87  		if phyDB, ok := phyDBs[tableName.Schema.String()]; ok {
    88  			tableName.Schema.O = phyDB
    89  			tableName.Schema.L = strings.ToLower(phyDB)
    90  		}
    91  	}
    92  }
    93  
    94  func generateUnshardingSQL(stmt ast.StmtNode) (string, error) {
    95  	s := &strings.Builder{}
    96  	ctx := format.NewRestoreCtx(format.EscapeRestoreFlags, s)
    97  	_ = stmt.Restore(ctx)
    98  	return s.String(), nil
    99  }
   100  
   101  // CreateSelectLastInsertIDPlan constructor of SelectLastInsertIDPlan
   102  func CreateSelectLastInsertIDPlan() *SelectLastInsertIDPlan {
   103  	return &SelectLastInsertIDPlan{}
   104  }
   105  
   106  // ExecuteIn implement Plan
   107  func (p *UnshardPlan) ExecuteIn(reqCtx *util.RequestContext, se Executor) (*mysql.Result, error) {
   108  	r, err := se.ExecuteSQL(reqCtx, reqCtx.Get(util.DefaultSlice).(string), p.db, p.sql)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  
   113  	// set last insert id to session
   114  	if _, ok := p.stmt.(*ast.InsertStmt); ok {
   115  		if r.InsertID != 0 {
   116  			se.SetLastInsertID(r.InsertID)
   117  		}
   118  	}
   119  
   120  	return r, nil
   121  }
   122  
   123  // ExecuteIn implement Plan
   124  func (p *SelectLastInsertIDPlan) ExecuteIn(reqCtx *util.RequestContext, se Executor) (*mysql.Result, error) {
   125  	r := createLastInsertIDResult(se.GetLastInsertID())
   126  	return r, nil
   127  }
   128  
   129  func createLastInsertIDResult(lastInsertID uint64) *mysql.Result {
   130  	name := "last_insert_id()"
   131  	var column = 1
   132  	var rows [][]string
   133  	var names = []string{
   134  		name,
   135  	}
   136  
   137  	var t = fmt.Sprintf("%d", lastInsertID)
   138  	rows = append(rows, []string{t})
   139  
   140  	r := new(mysql.Resultset)
   141  
   142  	var values = make([][]interface{}, len(rows))
   143  	for i := range rows {
   144  		values[i] = make([]interface{}, column)
   145  		for j := range rows[i] {
   146  			values[i][j] = rows[i][j]
   147  		}
   148  	}
   149  
   150  	r, _ = mysql.BuildResultset(nil, names, values)
   151  	ret := &mysql.Result{
   152  		Resultset: r,
   153  	}
   154  
   155  	return ret
   156  }