github.com/XiaoMi/Gaea@v1.2.5/proxy/plan/plan_insert.go (about)

     1  // Copyright 2019 The Gaea Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/XiaoMi/Gaea/core/errors"
    21  	"github.com/XiaoMi/Gaea/log"
    22  	"github.com/XiaoMi/Gaea/mysql"
    23  	"github.com/XiaoMi/Gaea/parser/ast"
    24  	driver "github.com/XiaoMi/Gaea/parser/tidb-types/parser_driver"
    25  	"github.com/XiaoMi/Gaea/proxy/router"
    26  	"github.com/XiaoMi/Gaea/proxy/sequence"
    27  	"github.com/XiaoMi/Gaea/util"
    28  )
    29  
// InsertPlan is the plan for insert statement
//
// It is created by NewInsertPlan and populated by HandleInsertStmt, which
// rewrites stmt in place and stores the per-shard SQLs in sqls; ExecuteIn then
// runs those SQLs.
type InsertPlan struct {
	basePlan
	// NOTE(review): presumably supplies db, router, tableRules and result,
	// which the handlers in this file read through p; confirm in StmtInfo.
	*StmtInfo

	// stmt is the INSERT statement being planned. It is modified in place:
	// the table name is decorated, nextval() calls are replaced and column
	// qualifiers are stripped.
	stmt *ast.InsertStmt

	// table is the lower-cased target table name.
	table string
	// isAssignmentMode is true for the INSERT INTO tbl SET col=val form.
	isAssignmentMode bool
	// shardingColumnIndex is the position of the sharding column in the SET
	// list (assignment mode) or the column list (VALUES mode); -1 until found.
	shardingColumnIndex int

	// sequences provides the global sequences used to replace nextval().
	sequences *sequence.SequenceManager

	// sqls holds the generated SQLs passed to Executor.ExecuteSQLs.
	// NOTE(review): nesting is defined by generateShardingSQLs — confirm
	// which key is the slice and which is the db.
	sqls map[string]map[string][]string
}
    45  
    46  // NewInsertPlan constructor of InsertPlan
    47  func NewInsertPlan(db string, sql string, r *router.Router, seq *sequence.SequenceManager) *InsertPlan {
    48  	return &InsertPlan{
    49  		StmtInfo:            NewStmtInfo(db, sql, r),
    50  		shardingColumnIndex: -1,
    51  		sequences:           seq,
    52  	}
    53  }
    54  
    55  // GetStmt return InsertStmt
    56  func (s *InsertPlan) GetStmt() *ast.InsertStmt {
    57  	return s.stmt
    58  }
    59  
    60  // HandleInsertStmt build a InsertPlan
    61  func HandleInsertStmt(p *InsertPlan, stmt *ast.InsertStmt) error {
    62  	p.stmt = stmt
    63  
    64  	if err := precheckInsertStmt(p); err != nil {
    65  		return err
    66  	}
    67  
    68  	// 处理全局表成功时会触发fastReturn
    69  	fastReturn, err := handleInsertTableRefs(p)
    70  	if err != nil {
    71  		return fmt.Errorf("handleInsertTableRefs error: %v", err)
    72  	}
    73  	if fastReturn {
    74  		return nil
    75  	}
    76  
    77  	if err := handleInsertGlobalSequenceValue(p); err != nil {
    78  		return fmt.Errorf("handleInsertGlobalSequenceValue error: %v", err)
    79  	}
    80  
    81  	if err := handleInsertColumnNames(p); err != nil {
    82  		return fmt.Errorf("handleInsertColumnNames error: %v", err)
    83  	}
    84  
    85  	if err := handleInsertOnDuplicate(p); err != nil {
    86  		return fmt.Errorf("handleInsertOnDuplicate error: %v", err)
    87  	}
    88  
    89  	if err := handleInsertValues(p); err != nil {
    90  		return fmt.Errorf("handleInsertValues error: %v", err)
    91  	}
    92  
    93  	sqls, err := generateShardingSQLs(p.stmt, p.result, p.router)
    94  	if err != nil {
    95  		log.Warn("generate insert sql failed, %v", err)
    96  		return err
    97  	}
    98  
    99  	p.sqls = sqls
   100  
   101  	return nil
   102  }
   103  
   104  func precheckInsertStmt(p *InsertPlan) error {
   105  	stmt := p.stmt
   106  	// doesn't support insert into select...
   107  	if stmt.Select != nil {
   108  		return errors.ErrSelectInInsert
   109  	}
   110  
   111  	// INSERT INTO tbl SET col=val, ...
   112  	if len(stmt.Setlist) != 0 {
   113  		p.isAssignmentMode = true
   114  		return nil
   115  	}
   116  
   117  	if len(stmt.Columns) == 0 {
   118  		return errors.ErrIRNoColumns
   119  	}
   120  
   121  	values := stmt.Lists[0]
   122  	if len(stmt.Columns) != len(values) {
   123  		return fmt.Errorf("column count doesn't match value count")
   124  	}
   125  
   126  	return nil
   127  }
   128  
   129  func handleInsertTableRefs(p *InsertPlan) (fastReturn bool, err error) {
   130  	if p.stmt.Table.TableRefs.Right != nil {
   131  		return false, fmt.Errorf("have multi tables in insert")
   132  	}
   133  	tableSource, ok := p.stmt.Table.TableRefs.Left.(*ast.TableSource)
   134  	if !ok {
   135  		return false, fmt.Errorf("not a table source")
   136  	}
   137  	tableName := tableSource.Source.(*ast.TableName)
   138  	p.table = tableName.Name.L
   139  
   140  	rule, need, err := NeedCreateTableNameDecoratorWithoutAlias(p.StmtInfo, tableName)
   141  	if err != nil {
   142  		return false, fmt.Errorf("check table name need to decorate error: %v", err)
   143  	}
   144  
   145  	if !need {
   146  		// 如果不需要装饰, 不应该走到分表逻辑, 直接报错
   147  		return false, fmt.Errorf("not a sharding table")
   148  	}
   149  
   150  	decorator, err := CreateTableNameDecorator(tableName, rule, p.GetRouteResult())
   151  	if err != nil {
   152  		return false, fmt.Errorf("create table name decorator error: %v", err)
   153  	}
   154  
   155  	tableSource.Source = decorator
   156  
   157  	// 如果是全局表, 则将记录写入所有分片
   158  	if rule.GetType() == router.GlobalTableRuleType {
   159  		p.result.db = rule.GetDB()
   160  		p.result.table = rule.GetTable()
   161  		p.result.indexes = rule.GetSubTableIndexes()
   162  		sqls, err := generateShardingSQLs(p.stmt, p.result, p.router)
   163  		if err != nil {
   164  			return false, fmt.Errorf("generate global table insert sql error: %v", err)
   165  		}
   166  		p.sqls = sqls
   167  		return true, nil
   168  	}
   169  
   170  	return false, nil
   171  }
   172  
   173  func handleInsertColumnNames(p *InsertPlan) error {
   174  	if p.isAssignmentMode {
   175  		// INSERT INTO tbl SET col = val, ...
   176  		for i, assignment := range p.stmt.Setlist {
   177  			col := assignment.Column
   178  			removeSchemaAndTableInfoInColumnName(col)
   179  			columnName := col.Name.L
   180  			rule := p.tableRules[p.table]
   181  			if columnName == rule.GetShardingColumn() {
   182  				p.shardingColumnIndex = i
   183  			}
   184  		}
   185  	} else {
   186  		// INSERT INTO tbl (col, ...) VALUES (val, ...)
   187  		for i, col := range p.stmt.Columns {
   188  			removeSchemaAndTableInfoInColumnName(col)
   189  			columnName := col.Name.L
   190  			rule := p.tableRules[p.table]
   191  			if columnName == rule.GetShardingColumn() {
   192  				p.shardingColumnIndex = i
   193  			}
   194  		}
   195  	}
   196  	if p.shardingColumnIndex == -1 {
   197  		return fmt.Errorf("sharding column not found")
   198  	}
   199  	return nil
   200  }
   201  
   202  // 只有一个表, 直接去掉DB名和表名, 就不需要加装饰器了
   203  func removeSchemaAndTableInfoInColumnName(column *ast.ColumnName) {
   204  	column.Schema.O = ""
   205  	column.Schema.L = ""
   206  	column.Table.O = ""
   207  	column.Table.L = ""
   208  }
   209  
   210  // TODO: refactor
   211  func handleInsertValues(p *InsertPlan) error {
   212  	// assignment mode
   213  	if p.isAssignmentMode {
   214  		valueItem := p.stmt.Setlist[p.shardingColumnIndex].Expr
   215  		switch x := valueItem.(type) {
   216  		case *driver.ValueExpr:
   217  			v, err := util.GetValueExprResult(x)
   218  			if err != nil {
   219  				return fmt.Errorf("get value expr result failed, %v", err)
   220  			}
   221  			if v == nil {
   222  				return fmt.Errorf("sharding value cannot be null")
   223  			}
   224  			routeIdx, err := p.tableRules[p.table].FindTableIndex(v)
   225  			if err != nil {
   226  				return fmt.Errorf("find table index error: %v", err)
   227  			}
   228  			p.result.Inter([]int{routeIdx})
   229  		}
   230  		return nil
   231  	}
   232  
   233  	// not assignment mode
   234  	for _, valueList := range p.stmt.Lists {
   235  		valueItem := valueList[p.shardingColumnIndex]
   236  		switch x := valueItem.(type) {
   237  		case *driver.ValueExpr:
   238  			v, err := util.GetValueExprResult(x)
   239  			if err != nil {
   240  				return fmt.Errorf("get value expr result failed, %v", err)
   241  			}
   242  			if v == nil {
   243  				return fmt.Errorf("sharding value cannot be null")
   244  			}
   245  			routeIdx, err := p.tableRules[p.table].FindTableIndex(v)
   246  			if err != nil {
   247  				return fmt.Errorf("find table index error: %v", err)
   248  			}
   249  			p.result.Inter([]int{routeIdx})
   250  		}
   251  	}
   252  	if len(p.result.GetShardIndexes()) == 0 {
   253  		return fmt.Errorf("batch insert has cross slice values or no route found")
   254  	}
   255  	return nil
   256  }
   257  
   258  // check on duplicate key
   259  // 不管分片表的配置信息, 只要在OnDuplicate出现分片列, 就返回错误
   260  // 去掉ColumnName中的DB名和表名
   261  func handleInsertOnDuplicate(p *InsertPlan) error {
   262  	if p.stmt.OnDuplicate == nil {
   263  		return nil
   264  	}
   265  
   266  	shardingColumnName := p.tableRules[p.table].GetShardingColumn()
   267  	for _, a := range p.stmt.OnDuplicate {
   268  		if a.Column.Name.L == shardingColumnName {
   269  			return errors.ErrUpdateKey
   270  		}
   271  		removeSchemaAndTableInfoInColumnName(a.Column)
   272  	}
   273  
   274  	return nil
   275  }
   276  
   277  // 处理全局序列号, 目前一条SQL中只允许一个列使用全局序列号
   278  func handleInsertGlobalSequenceValue(p *InsertPlan) error {
   279  	seq, ok := p.sequences.GetSequence(p.db, p.table)
   280  	if !ok {
   281  		return nil
   282  	}
   283  	pkName := seq.GetPKName()
   284  
   285  	// not assignment mode
   286  	if p.isAssignmentMode {
   287  		for _, assignment := range p.stmt.Setlist {
   288  			columnName := assignment.Column.Name.L
   289  			if columnName == pkName {
   290  				if x, ok := assignment.Expr.(*ast.FuncCallExpr); ok {
   291  					if x.FnName.L == "nextval" {
   292  						id, err := seq.NextSeq()
   293  						if err != nil {
   294  							return fmt.Errorf("get next seq error: %v", err)
   295  						}
   296  						assignment.Expr = ast.NewValueExpr(id)
   297  						break
   298  					}
   299  				}
   300  			}
   301  		}
   302  		return nil
   303  	}
   304  
   305  	// not assignment mode
   306  	var seqIndex = -1
   307  	for i, column := range p.stmt.Columns {
   308  		columnName := column.Name.L
   309  		if columnName == pkName {
   310  			seqIndex = i
   311  			break
   312  		}
   313  	}
   314  
   315  	// global sequence column not found
   316  	if seqIndex == -1 {
   317  		return nil
   318  	}
   319  
   320  	for _, valueList := range p.stmt.Lists {
   321  		if x, ok := valueList[seqIndex].(*ast.FuncCallExpr); ok {
   322  			if x.FnName.L == "nextval" {
   323  				id, err := seq.NextSeq()
   324  				if err != nil {
   325  					return fmt.Errorf("get next seq error: %v", err)
   326  				}
   327  				valueList[seqIndex] = ast.NewValueExpr(id)
   328  			}
   329  		}
   330  	}
   331  
   332  	return nil
   333  }
   334  
   335  // ExecuteIn implement Plan
   336  func (s *InsertPlan) ExecuteIn(reqCtx *util.RequestContext, sess Executor) (*mysql.Result, error) {
   337  	rs, err := sess.ExecuteSQLs(reqCtx, s.sqls)
   338  	if err != nil {
   339  		return nil, fmt.Errorf("execute in InsertPlan error: %v", err)
   340  	}
   341  
   342  	r, err := MergeExecResult(rs)
   343  	if err != nil {
   344  		return nil, err
   345  	}
   346  
   347  	if r.InsertID != 0 {
   348  		sess.SetLastInsertID(r.InsertID)
   349  	}
   350  
   351  	return r, nil
   352  }