github.com/XiaoMi/Gaea@v1.2.5/proxy/plan/plan.go (about)

     1  // Copyright 2019 The Gaea Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/XiaoMi/Gaea/mysql"
    22  	"github.com/XiaoMi/Gaea/parser"
    23  	"github.com/XiaoMi/Gaea/parser/ast"
    24  	"github.com/XiaoMi/Gaea/parser/format"
    25  	"github.com/XiaoMi/Gaea/proxy/router"
    26  	"github.com/XiaoMi/Gaea/proxy/sequence"
    27  	"github.com/XiaoMi/Gaea/util"
    28  	"github.com/XiaoMi/Gaea/util/hack"
    29  )
    30  
    31  // type check
    32  var _ Plan = &UnshardPlan{}
    33  var _ Plan = &SelectPlan{}
    34  var _ Plan = &DeletePlan{}
    35  var _ Plan = &UpdatePlan{}
    36  var _ Plan = &InsertPlan{}
    37  var _ Plan = &SelectLastInsertIDPlan{}
    38  
// Plan is the interface implemented by every execution plan
// (select, insert, update, delete, ...).
type Plan interface {
	// ExecuteIn runs the plan through the given Executor and returns the result.
	ExecuteIn(*util.RequestContext, Executor) (*mysql.Result, error)

	// Size is only used by the plan cache.
	Size() int
}
    46  
// Executor is what a Plan uses to send SQL to the backend.
// TODO: move to package executor
type Executor interface {

	// ExecuteSQL executes a single SQL statement, sharded or not.
	ExecuteSQL(ctx *util.RequestContext, slice, db, sql string) (*mysql.Result, error)

	// ExecuteSQLs executes sharded SQL statements.
	// NOTE(review): the map is presumably keyed by slice name, then database name — confirm against callers.
	ExecuteSQLs(*util.RequestContext, map[string]map[string][]string) ([]*mysql.Result, error)

	// SetLastInsertID sets the last insert id; used when executing INSERT.
	SetLastInsertID(uint64)

	// GetLastInsertID returns the last insert id.
	GetLastInsertID() uint64
}
    61  
// Checker is a Visitor that checks whether a statement touches a sharded
// table, and whether every table reference has database information.
type Checker struct {
	// db is the session database; see NewChecker.
	db            string
	// router is queried for shard rules by (db, table).
	router        *router.Router
	hasShardTable bool // whether a sharded table was found
	dbInvalid     bool // whether the SQL should fail with "No database selected"
	// tableNames collects the visited table names that have no shard rule.
	tableNames    []*ast.TableName
}
    70  
    71  // NewChecker db为USE db中设置的DB名. 如果没有执行USE db, 则为空字符串
    72  func NewChecker(db string, router *router.Router) *Checker {
    73  	return &Checker{
    74  		db:            db,
    75  		router:        router,
    76  		hasShardTable: false,
    77  		dbInvalid:     false,
    78  	}
    79  }
    80  
// GetUnshardTableNames returns the table names collected during the visit,
// i.e. the visited tables for which no shard rule was found.
func (s *Checker) GetUnshardTableNames() []*ast.TableName {
	return s.tableNames
}
    84  
// IsDatabaseInvalid reports whether the statement lacks database information.
// If a table name appears without a database and no session database is set,
// the plan is invalid and the following error should be returned:
// ERROR 1046 (3D000): No database selected
func (s *Checker) IsDatabaseInvalid() bool {
	return s.dbInvalid
}
    90  
// IsShard reports whether a sharded table was found during the visit.
func (s *Checker) IsShard() bool {
	return s.hasShardTable
}
    95  
    96  // Enter for node visit
    97  func (s *Checker) Enter(n ast.Node) (node ast.Node, skipChildren bool) {
    98  	if s.hasShardTable {
    99  		return n, true
   100  	}
   101  	switch nn := n.(type) {
   102  	case *ast.TableName:
   103  		if s.isTableNameDatabaseInvalid(nn) {
   104  			s.dbInvalid = true
   105  			return n, true
   106  		}
   107  		has := s.hasShardTableInTableName(nn)
   108  		if has {
   109  			s.hasShardTable = true
   110  			return n, true
   111  		}
   112  		s.tableNames = append(s.tableNames, nn)
   113  	}
   114  	return n, false
   115  }
   116  
   117  // Leave for node visit
   118  func (s *Checker) Leave(n ast.Node) (node ast.Node, ok bool) {
   119  	return n, !s.dbInvalid && !s.hasShardTable
   120  }
   121  
   122  // 如果ast.TableName不带DB名, 且Session未设置DB, 则是不允许的SQL, 应该返回No database selected
   123  func (s *Checker) isTableNameDatabaseInvalid(n *ast.TableName) bool {
   124  	return s.db == "" && n.Schema.L == ""
   125  }
   126  
   127  func (s *Checker) hasShardTableInTableName(n *ast.TableName) bool {
   128  	db := n.Schema.L
   129  	if db == "" {
   130  		db = s.db
   131  	}
   132  	table := n.Name.L
   133  	_, ok := s.router.GetShardRule(db, table)
   134  	return ok
   135  }
   136  
   137  func (s *Checker) hasShardTableInColumnName(n *ast.ColumnName) bool {
   138  	db := n.Schema.L
   139  	if db == "" {
   140  		db = s.db
   141  	}
   142  	table := n.Table.L
   143  	_, ok := s.router.GetShardRule(db, table)
   144  	return ok
   145  }
   146  
// basePlan is a field-less helper type that provides the Size method below.
// NOTE(review): presumably embedded by the concrete plan types — confirm.
type basePlan struct{}
   148  
// Size always returns 1.
func (*basePlan) Size() int {
	return 1
}
   152  
// StmtInfo holds the properties shared by the different kinds of Plan.
type StmtInfo struct {
	db               string // session db
	sql              string // origin sql
	router           *router.Router
	tableRules       map[string]router.Rule // key = table name, value = router.Rule; the sharded tables in use
	globalTableRules map[string]router.Rule // the global tables in use
	result           *RouteResult
}
   162  
// TableAliasStmtInfo is for statements that use table aliases and depend on
// them for route calculation; currently UPDATE and SELECT.
// INSERT may also use a table alias, but since only one table is involved the
// alias can simply be dropped, so INSERT does not need this type.
type TableAliasStmtInfo struct {
	*StmtInfo
	tableAlias map[string]string // key = table alias, value = table
	hintPhyDB  string            // physical DB name given by the DATABASE() function under mycat sharding
}
   170  
   171  // BuildPlan build plan for ast
   172  func BuildPlan(stmt ast.StmtNode, phyDBs map[string]string, db, sql string, router *router.Router, seq *sequence.SequenceManager) (Plan, error) {
   173  	if IsSelectLastInsertIDStmt(stmt) {
   174  		return CreateSelectLastInsertIDPlan(), nil
   175  	}
   176  
   177  	if estmt, ok := stmt.(*ast.ExplainStmt); ok {
   178  		return buildExplainPlan(estmt, phyDBs, db, sql, router, seq)
   179  	}
   180  
   181  	checker := NewChecker(db, router)
   182  	stmt.Accept(checker)
   183  
   184  	if checker.IsDatabaseInvalid() {
   185  		return nil, fmt.Errorf("no database selected") // TODO: return standard MySQL error
   186  	}
   187  
   188  	if checker.IsShard() {
   189  		return buildShardPlan(stmt, db, sql, router, seq)
   190  	}
   191  	return CreateUnshardPlan(stmt, phyDBs, db, checker.GetUnshardTableNames())
   192  }
   193  
   194  func buildShardPlan(stmt ast.StmtNode, db string, sql string, router *router.Router, seq *sequence.SequenceManager) (Plan, error) {
   195  	switch s := stmt.(type) {
   196  	case *ast.SelectStmt:
   197  		plan := NewSelectPlan(db, sql, router)
   198  		if err := HandleSelectStmt(plan, s); err != nil {
   199  			return nil, err
   200  		}
   201  		return plan, nil
   202  	case *ast.InsertStmt:
   203  		// InsertStmt contains REPLACE statement
   204  		plan := NewInsertPlan(db, sql, router, seq)
   205  		if err := HandleInsertStmt(plan, s); err != nil {
   206  			return nil, err
   207  		}
   208  		return plan, nil
   209  	case *ast.UpdateStmt:
   210  		plan := NewUpdatePlan(s, db, sql, router)
   211  		if err := HandleUpdatePlan(plan); err != nil {
   212  			return nil, err
   213  		}
   214  		return plan, nil
   215  	case *ast.DeleteStmt:
   216  		plan := NewDeletePlan(s, db, sql, router)
   217  		if err := HandleDeletePlan(plan); err != nil {
   218  			return nil, err
   219  		}
   220  		return plan, nil
   221  	default:
   222  		return nil, fmt.Errorf("stmt type does not support shard now")
   223  	}
   224  }
   225  
   226  // NewStmtInfo constructor of StmtInfo
   227  func NewStmtInfo(db string, sql string, r *router.Router) *StmtInfo {
   228  	return &StmtInfo{
   229  		db:               db,
   230  		sql:              sql,
   231  		router:           r,
   232  		tableRules:       make(map[string]router.Rule),
   233  		globalTableRules: make(map[string]router.Rule),
   234  		result:           NewRouteResult("", "", nil), // nil route result
   235  	}
   236  }
   237  
   238  // NewTableAliasStmtInfo means table alias StmtInfo
   239  func NewTableAliasStmtInfo(db string, sql string, r *router.Router) *TableAliasStmtInfo {
   240  	return &TableAliasStmtInfo{
   241  		StmtInfo:   NewStmtInfo(db, sql, r),
   242  		tableAlias: make(map[string]string),
   243  	}
   244  }
   245  
// GetRouteResult returns the route result held by this StmtInfo.
func (s *StmtInfo) GetRouteResult() *RouteResult {
	return s.result
}
   250  
   251  func (s *StmtInfo) checkAndGetDB(db string) (string, error) {
   252  	if db != "" && db != s.db {
   253  		return "", fmt.Errorf("db not match")
   254  	}
   255  	return s.db, nil
   256  }
   257  
   258  // RecordShardTable 将表信息记录到StmtInfo中, 并返回表信息对应的路由规则
   259  func (s *StmtInfo) RecordShardTable(db, table string) (router.Rule, error) {
   260  	rule, err := s.getShardRule(db, table)
   261  	if err != nil {
   262  		return nil, fmt.Errorf("get shard rule error, db: %s, table: %s, err: %v", db, table, err)
   263  	}
   264  
   265  	if err := s.checkStmtRouteResult(rule); err != nil {
   266  		return nil, fmt.Errorf("check route result error, db: %s, table: %s, err: %v", db, table, err)
   267  	}
   268  
   269  	return rule, nil
   270  }
   271  
   272  // 根据db和table获取Rule
   273  // 如果只传table, 则使用session db.
   274  func (s *StmtInfo) getShardRule(db, table string) (router.Rule, error) {
   275  	validDB, err := s.checkAndGetDB(db)
   276  	if err != nil {
   277  		return nil, err
   278  	}
   279  
   280  	rule, ok := s.router.GetShardRule(validDB, table) // 这里一定是ShardingRule, 不会是DefaultRule
   281  	if !ok {
   282  		return nil, fmt.Errorf("rule not found")
   283  	}
   284  
   285  	if rule.GetType() == router.GlobalTableRuleType {
   286  		s.globalTableRules[table] = rule
   287  	} else {
   288  		s.tableRules[table] = rule // 记录已经使用到的rule
   289  	}
   290  	return rule, nil
   291  }
   292  
   293  // 检查路由规则与现有RouteResult是否一致
   294  // 一致的标准: 与RouteResult的db, table一致
   295  func (s *StmtInfo) checkStmtRouteResult(rule router.Rule) error {
   296  	// 如果是全局表, 不需要检查路由规则是否一致, 只记录该规则, 直接返回即可
   297  	if rule.GetType() == router.GlobalTableRuleType {
   298  		return nil
   299  	}
   300  
   301  	db := rule.GetDB()
   302  	var table string
   303  	if linkedRule, ok := rule.(*router.LinkedRule); ok {
   304  		table = linkedRule.GetParentTable()
   305  	} else {
   306  		table = rule.GetTable()
   307  	}
   308  
   309  	if s.result.db == "" && s.result.table == "" {
   310  		s.result.db = db
   311  		s.result.table = table
   312  		s.result.indexes = rule.GetSubTableIndexes()
   313  	} else {
   314  		if err := s.result.Check(db, table); err != nil {
   315  			return fmt.Errorf("check db and table error: %v", err)
   316  		}
   317  	}
   318  
   319  	return nil
   320  }
   321  
   322  // 用于WHERE条件或JOIN ON条件中, 只存在列名时, 查找对应的路由规则
   323  func (s *StmtInfo) getSettedRuleByColumnName(column string) (router.Rule, bool, error) {
   324  	var columnExistsInShardingTables int // 记录分片表名出现在分片表中分片列的次数
   325  	var ret router.Rule
   326  	for _, r := range s.tableRules {
   327  		if r.GetShardingColumn() == column {
   328  			columnExistsInShardingTables++
   329  			ret = r
   330  		}
   331  	}
   332  
   333  	if columnExistsInShardingTables > 1 {
   334  		return nil, false, fmt.Errorf("column %s is ambiguous for sharding", column)
   335  	}
   336  
   337  	return ret, ret != nil, nil
   338  }
   339  
   340  // 处理SELECT只含有全局表的情况
   341  // 这种情况只路由到默认分片
   342  // 如果有多个全局表, 则只取第一个全局表的配置, 因此需要业务上保证这些全局表的配置是一致的.
   343  func postHandleGlobalTableRouteResultInQuery(p *StmtInfo) error {
   344  	if len(p.tableRules) == 0 && len(p.globalTableRules) != 0 {
   345  		var tableName string
   346  		var rule router.Rule
   347  		for t, r := range p.globalTableRules {
   348  			tableName = t
   349  			rule = r
   350  			break
   351  		}
   352  		p.result.db = rule.GetDB()
   353  		p.result.table = tableName
   354  		p.result.indexes = []int{0} // 全局表SELECT只取默认分片
   355  	}
   356  	return nil
   357  }
   358  
   359  // 处理UPDATE, DELETE只含有全局表的情况
   360  // 这种情况只路由到默认分片
   361  // 如果有多个全局表, 则只取第一个全局表的配置, 因此需要业务上保证这些全局表的配置是一致的.
   362  func postHandleGlobalTableRouteResultInModify(p *StmtInfo) error {
   363  	if len(p.tableRules) == 0 && len(p.globalTableRules) != 0 {
   364  		var tableName string
   365  		var rule router.Rule
   366  		for t, r := range p.globalTableRules {
   367  			tableName = t
   368  			rule = r
   369  			break
   370  		}
   371  		p.result.db = rule.GetDB()
   372  		p.result.table = tableName
   373  		p.result.indexes = rule.GetSubTableIndexes()
   374  	}
   375  	return nil
   376  }
   377  
   378  // RecordSubqueryTableAlias 记录表名位置的子查询的别名, 便于后续处理
   379  // 返回已存在Rule的第一个 (任意一个即可)
   380  // 限制: 子查询中的表对应的路由规则必须与外层查询相关联, 或者为全局表
   381  func (t *TableAliasStmtInfo) RecordSubqueryTableAlias(alias string) (router.Rule, error) {
   382  	if alias == "" {
   383  		return nil, fmt.Errorf("subquery table alias is nil")
   384  	}
   385  
   386  	if len(t.tableRules) == 0 {
   387  		return nil, fmt.Errorf("no explicit table exist except subquery")
   388  	}
   389  
   390  	table := "gaea_subquery_" + alias
   391  	if err := t.setTableAlias(table, alias); err != nil {
   392  		return nil, fmt.Errorf("set subquery table alias error: %v", err)
   393  	}
   394  
   395  	var rule router.Rule
   396  	for _, r := range t.tableRules {
   397  		rule = r
   398  		break
   399  	}
   400  
   401  	t.tableRules[table] = rule
   402  	return rule, nil
   403  }
   404  
   405  // GetSettedRuleFromColumnInfo 用于WHERE条件或JOIN ON条件中, 查找列名对应的路由规则
   406  func (t *TableAliasStmtInfo) GetSettedRuleFromColumnInfo(db, table, column string) (router.Rule, bool, bool, error) {
   407  	if db == "" && table == "" {
   408  		rule, need, err := t.getSettedRuleByColumnName(column)
   409  		return rule, need, false, err
   410  	}
   411  
   412  	rule, isAlias, err := t.getSettedRuleFromTable(db, table)
   413  	return rule, rule != nil, isAlias, err
   414  }
   415  
   416  // 用于WHERE条件或JOIN ON条件中, 只存在列名时, 查找对应的路由规则
   417  func (t *TableAliasStmtInfo) getSettedRuleByColumnName(column string) (router.Rule, bool, error) {
   418  	var columnExistsInShardingTables int // 记录分片表名出现在分片表中分片列的次数
   419  	var ret router.Rule
   420  	for _, r := range t.tableRules {
   421  		if r.GetShardingColumn() == column {
   422  			columnExistsInShardingTables++
   423  			ret = r
   424  		}
   425  	}
   426  
   427  	if columnExistsInShardingTables > 1 {
   428  		return nil, false, fmt.Errorf("column %s is ambiguous for sharding", column)
   429  	}
   430  
   431  	return ret, ret != nil, nil
   432  }
   433  
   434  // 获取FROM TABLE列表中的表数据
   435  // 用于FieldList和Where条件中列名的判断
   436  func (t *TableAliasStmtInfo) getSettedRuleFromTable(db, table string) (router.Rule, bool, error) {
   437  	_, err := t.checkAndGetDB(db)
   438  	if err != nil {
   439  		return nil, false, err
   440  	}
   441  	if rule, ok := t.tableRules[table]; ok {
   442  		return rule, false, nil
   443  	}
   444  
   445  	if rule, ok := t.globalTableRules[table]; ok {
   446  		return rule, false, nil
   447  	}
   448  
   449  	if originTable, ok := t.getAliasTable(table); ok {
   450  		if rule, ok := t.tableRules[originTable]; ok {
   451  			return rule, true, nil
   452  		}
   453  		if rule, ok := t.globalTableRules[originTable]; ok {
   454  			return rule, true, nil
   455  		}
   456  	}
   457  
   458  	return nil, false, fmt.Errorf("rule not found")
   459  }
   460  
   461  // RecordShardTable 将表信息记录到StmtInfo中, 并返回表信息对应的路由规则
   462  func (t *TableAliasStmtInfo) RecordShardTable(db, table, alias string) (router.Rule, error) {
   463  	rule, err := t.StmtInfo.RecordShardTable(db, table)
   464  	if err != nil {
   465  		return nil, fmt.Errorf("record shard table error, db: %s, table: %s, alias: %s, err: %v", db, table, alias, err)
   466  	}
   467  
   468  	if alias != "" {
   469  		if err := t.setTableAlias(table, alias); err != nil {
   470  			return nil, fmt.Errorf("set table alias error: %v", err)
   471  		}
   472  	}
   473  
   474  	return rule, nil
   475  }
   476  
   477  func (t *TableAliasStmtInfo) setTableAlias(table, alias string) error {
   478  	// if not set, set without check
   479  	originTable, ok := t.tableAlias[alias]
   480  	if !ok {
   481  		t.tableAlias[alias] = table
   482  		return nil
   483  	}
   484  
   485  	if originTable != table {
   486  		return fmt.Errorf("table alias is set but not match, table: %s, originTable: %s", table, originTable)
   487  	}
   488  
   489  	// already set, return
   490  	return nil
   491  }
   492  
// getAliasTable returns the table bound to alias and whether the alias is known.
func (t *TableAliasStmtInfo) getAliasTable(alias string) (string, bool) {
	table, ok := t.tableAlias[alias]
	return table, ok
}
   497  
   498  // 根据StmtNode和路由信息生成分片SQL
   499  func generateShardingSQLs(stmt ast.StmtNode, result *RouteResult, router *router.Router) (map[string]map[string][]string, error) {
   500  	ret := make(map[string]map[string][]string)
   501  
   502  	for result.HasNext() {
   503  		sb := &strings.Builder{}
   504  		ctx := format.NewRestoreCtx(format.EscapeRestoreFlags, sb)
   505  		if err := stmt.Restore(ctx); err != nil {
   506  			return nil, err
   507  		}
   508  
   509  		index := result.Next()
   510  		rule, ok := router.GetShardRule(result.db, result.table)
   511  		if !ok {
   512  			return nil, fmt.Errorf("cannot find shard rule, db: %s, table: %s", result.db, result.table)
   513  		}
   514  		sliceIndex := rule.GetSliceIndexFromTableIndex(index)
   515  		sliceName := rule.GetSlice(sliceIndex)
   516  		dbName, _ := rule.GetDatabaseNameByTableIndex(index)
   517  		sliceSQLs, ok := ret[sliceName]
   518  		if !ok {
   519  			sliceSQLs = make(map[string][]string)
   520  			ret[sliceName] = sliceSQLs
   521  		}
   522  
   523  		ret[sliceName][dbName] = append(ret[sliceName][dbName], sb.String())
   524  	}
   525  
   526  	result.Reset() // must reset the cursor for next call
   527  
   528  	return ret, nil
   529  }
   530  
   531  // 根据原始SQL生成后端对应slice和db的SQL
   532  func generateSQLResultFromOriginSQL(sql string, result *RouteResult, router *router.Router) (map[string]map[string][]string, error) {
   533  	rule := router.GetRule(result.db, result.table)
   534  	indexes := rule.GetSubTableIndexes()
   535  	ret := make(map[string]map[string][]string)
   536  	for _, index := range indexes {
   537  		sliceIndex := rule.GetSliceIndexFromTableIndex(index)
   538  		sliceName := rule.GetSlice(sliceIndex)
   539  		dbName, _ := rule.GetDatabaseNameByTableIndex(index)
   540  		sliceSQLs, ok := ret[sliceName]
   541  		if !ok {
   542  			sliceSQLs = make(map[string][]string)
   543  			ret[sliceName] = sliceSQLs
   544  		}
   545  
   546  		ret[sliceName][dbName] = append(ret[sliceName][dbName], sql)
   547  	}
   548  
   549  	return ret, nil
   550  }
   551  
   552  // copy from newEmptyResultset
   553  // 注意去掉补充的列
   554  func newEmptyResultset(info *SelectPlan, stmt *ast.SelectStmt) *mysql.Resultset {
   555  	r := new(mysql.Resultset)
   556  
   557  	fieldLen := len(stmt.Fields.Fields)
   558  	fieldLen -= info.columnCount - info.originColumnCount
   559  
   560  	r.Fields = make([]*mysql.Field, fieldLen)
   561  	for i, expr := range stmt.Fields.Fields {
   562  		r.Fields[i] = &mysql.Field{}
   563  		if expr.WildCard != nil {
   564  			r.Fields[i].Name = []byte("*")
   565  		} else {
   566  			if expr.AsName.String() != "" {
   567  				r.Fields[i].Name = hack.Slice(expr.AsName.String())
   568  				name, _ := parser.NodeToStringWithoutQuote(expr.Expr)
   569  				r.Fields[i].OrgName = hack.Slice(name)
   570  			} else {
   571  				name, _ := parser.NodeToStringWithoutQuote(expr.Expr)
   572  				r.Fields[i].Name = hack.Slice(name)
   573  			}
   574  		}
   575  	}
   576  
   577  	r.Values = make([][]interface{}, 0)
   578  	r.RowDatas = make([]mysql.RowData, 0)
   579  
   580  	return r
   581  }