github.com/XiaoMi/Gaea@v1.2.5/proxy/plan/plan.go (about) 1 // Copyright 2019 The Gaea Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package plan 16 17 import ( 18 "fmt" 19 "strings" 20 21 "github.com/XiaoMi/Gaea/mysql" 22 "github.com/XiaoMi/Gaea/parser" 23 "github.com/XiaoMi/Gaea/parser/ast" 24 "github.com/XiaoMi/Gaea/parser/format" 25 "github.com/XiaoMi/Gaea/proxy/router" 26 "github.com/XiaoMi/Gaea/proxy/sequence" 27 "github.com/XiaoMi/Gaea/util" 28 "github.com/XiaoMi/Gaea/util/hack" 29 ) 30 31 // type check 32 var _ Plan = &UnshardPlan{} 33 var _ Plan = &SelectPlan{} 34 var _ Plan = &DeletePlan{} 35 var _ Plan = &UpdatePlan{} 36 var _ Plan = &InsertPlan{} 37 var _ Plan = &SelectLastInsertIDPlan{} 38 39 // Plan is a interface for select/insert etc. 40 type Plan interface { 41 ExecuteIn(*util.RequestContext, Executor) (*mysql.Result, error) 42 43 // only for cache 44 Size() int 45 } 46 47 // Executor TODO: move to package executor 48 type Executor interface { 49 50 // 执行分片或非分片单条SQL 51 ExecuteSQL(ctx *util.RequestContext, slice, db, sql string) (*mysql.Result, error) 52 53 // 执行分片SQL 54 ExecuteSQLs(*util.RequestContext, map[string]map[string][]string) ([]*mysql.Result, error) 55 56 // 用于执行INSERT时设置last insert id 57 SetLastInsertID(uint64) 58 59 GetLastInsertID() uint64 60 } 61 62 // Checker 用于检查SelectStmt是不是分表的Visitor, 以及是否包含DB信息 63 type Checker struct { 64 db string 65 router *router.Router 66 hasShardTable bool // 是否包含分片表 67 dbInvalid bool // SQL是否No database selected 68 tableNames []*ast.TableName 69 } 70 71 // NewChecker db为USE db中设置的DB名. 如果没有执行USE db, 则为空字符串 72 func NewChecker(db string, router *router.Router) *Checker { 73 return &Checker{ 74 db: db, 75 router: router, 76 hasShardTable: false, 77 dbInvalid: false, 78 } 79 } 80 81 func (s *Checker) GetUnshardTableNames() []*ast.TableName { 82 return s.tableNames 83 } 84 85 // IsDatabaseInvalid 判断执行计划中是否包含db信息, 如果不包含, 且又含有表名, 则是一个错的执行计划, 应该返回以下错误: 86 // ERROR 1046 (3D000): No database selected 87 func (s *Checker) IsDatabaseInvalid() bool { 88 return s.dbInvalid 89 } 90 91 // IsShard if is shard table 92 func (s *Checker) IsShard() bool { 93 return s.hasShardTable 94 } 95 96 // Enter for node visit 97 func (s *Checker) Enter(n ast.Node) (node ast.Node, skipChildren bool) { 98 if s.hasShardTable { 99 return n, true 100 } 101 switch nn := n.(type) { 102 case *ast.TableName: 103 if s.isTableNameDatabaseInvalid(nn) { 104 s.dbInvalid = true 105 return n, true 106 } 107 has := s.hasShardTableInTableName(nn) 108 if has { 109 s.hasShardTable = true 110 return n, true 111 } 112 s.tableNames = append(s.tableNames, nn) 113 } 114 return n, false 115 } 116 117 // Leave for node visit 118 func (s *Checker) Leave(n ast.Node) (node ast.Node, ok bool) { 119 return n, !s.dbInvalid && !s.hasShardTable 120 } 121 122 // 如果ast.TableName不带DB名, 且Session未设置DB, 则是不允许的SQL, 应该返回No database selected 123 func (s *Checker) isTableNameDatabaseInvalid(n *ast.TableName) bool { 124 return s.db == "" && n.Schema.L == "" 125 } 126 127 func (s *Checker) hasShardTableInTableName(n *ast.TableName) bool { 128 db := n.Schema.L 129 if db == "" { 130 db = s.db 131 } 132 table := n.Name.L 133 _, ok := s.router.GetShardRule(db, table) 134 return ok 135 } 136 137 func (s *Checker) hasShardTableInColumnName(n *ast.ColumnName) bool { 138 db := n.Schema.L 139 if db == "" { 140 db = s.db 141 } 142 table := n.Table.L 143 _, ok := s.router.GetShardRule(db, table) 144 return ok 145 } 146 147 type basePlan struct{} 148 149 func (*basePlan) Size() int { 150 return 1 151 } 152 153 // StmtInfo 各种Plan的一些公共属性 154 type StmtInfo struct { 155 db string // session db 156 sql string // origin sql 157 router *router.Router 158 tableRules map[string]router.Rule // key = table name, value = router.Rule, 记录使用到的分片表 159 globalTableRules map[string]router.Rule // 记录使用到的全局表 160 result *RouteResult 161 } 162 163 // TableAliasStmtInfo 使用到表别名, 且依赖表别名做路由计算的StmtNode, 目前包括UPDATE, SELECT 164 // INSERT也可以使用表别名, 但是由于只存在一个表, 可以直接去掉, 因此不需要. 165 type TableAliasStmtInfo struct { 166 *StmtInfo 167 tableAlias map[string]string // key = table alias, value = table 168 hintPhyDB string // 记录mycat分片时DATABASE()函数指定的物理DB名 169 } 170 171 // BuildPlan build plan for ast 172 func BuildPlan(stmt ast.StmtNode, phyDBs map[string]string, db, sql string, router *router.Router, seq *sequence.SequenceManager) (Plan, error) { 173 if IsSelectLastInsertIDStmt(stmt) { 174 return CreateSelectLastInsertIDPlan(), nil 175 } 176 177 if estmt, ok := stmt.(*ast.ExplainStmt); ok { 178 return buildExplainPlan(estmt, phyDBs, db, sql, router, seq) 179 } 180 181 checker := NewChecker(db, router) 182 stmt.Accept(checker) 183 184 if checker.IsDatabaseInvalid() { 185 return nil, fmt.Errorf("no database selected") // TODO: return standard MySQL error 186 } 187 188 if checker.IsShard() { 189 return buildShardPlan(stmt, db, sql, router, seq) 190 } 191 return CreateUnshardPlan(stmt, phyDBs, db, checker.GetUnshardTableNames()) 192 } 193 194 func buildShardPlan(stmt ast.StmtNode, db string, sql string, router *router.Router, seq *sequence.SequenceManager) (Plan, error) { 195 switch s := stmt.(type) { 196 case *ast.SelectStmt: 197 plan := NewSelectPlan(db, sql, router) 198 if err := HandleSelectStmt(plan, s); err != nil { 199 return nil, err 200 } 201 return plan, nil 202 case *ast.InsertStmt: 203 // InsertStmt contains REPLACE statement 204 plan := NewInsertPlan(db, sql, router, seq) 205 if err := HandleInsertStmt(plan, s); err != nil { 206 return nil, err 207 } 208 return plan, nil 209 case *ast.UpdateStmt: 210 plan := NewUpdatePlan(s, db, sql, router) 211 if err := HandleUpdatePlan(plan); err != nil { 212 return nil, err 213 } 214 return plan, nil 215 case *ast.DeleteStmt: 216 plan := NewDeletePlan(s, db, sql, router) 217 if err := HandleDeletePlan(plan); err != nil { 218 return nil, err 219 } 220 return plan, nil 221 default: 222 return nil, fmt.Errorf("stmt type does not support shard now") 223 } 224 } 225 226 // NewStmtInfo constructor of StmtInfo 227 func NewStmtInfo(db string, sql string, r *router.Router) *StmtInfo { 228 return &StmtInfo{ 229 db: db, 230 sql: sql, 231 router: r, 232 tableRules: make(map[string]router.Rule), 233 globalTableRules: make(map[string]router.Rule), 234 result: NewRouteResult("", "", nil), // nil route result 235 } 236 } 237 238 // NewTableAliasStmtInfo means table alias StmtInfo 239 func NewTableAliasStmtInfo(db string, sql string, r *router.Router) *TableAliasStmtInfo { 240 return &TableAliasStmtInfo{ 241 StmtInfo: NewStmtInfo(db, sql, r), 242 tableAlias: make(map[string]string), 243 } 244 } 245 246 // GetRouteResult get route result 247 func (s *StmtInfo) GetRouteResult() *RouteResult { 248 return s.result 249 } 250 251 func (s *StmtInfo) checkAndGetDB(db string) (string, error) { 252 if db != "" && db != s.db { 253 return "", fmt.Errorf("db not match") 254 } 255 return s.db, nil 256 } 257 258 // RecordShardTable 将表信息记录到StmtInfo中, 并返回表信息对应的路由规则 259 func (s *StmtInfo) RecordShardTable(db, table string) (router.Rule, error) { 260 rule, err := s.getShardRule(db, table) 261 if err != nil { 262 return nil, fmt.Errorf("get shard rule error, db: %s, table: %s, err: %v", db, table, err) 263 } 264 265 if err := s.checkStmtRouteResult(rule); err != nil { 266 return nil, fmt.Errorf("check route result error, db: %s, table: %s, err: %v", db, table, err) 267 } 268 269 return rule, nil 270 } 271 272 // 根据db和table获取Rule 273 // 如果只传table, 则使用session db. 274 func (s *StmtInfo) getShardRule(db, table string) (router.Rule, error) { 275 validDB, err := s.checkAndGetDB(db) 276 if err != nil { 277 return nil, err 278 } 279 280 rule, ok := s.router.GetShardRule(validDB, table) // 这里一定是ShardingRule, 不会是DefaultRule 281 if !ok { 282 return nil, fmt.Errorf("rule not found") 283 } 284 285 if rule.GetType() == router.GlobalTableRuleType { 286 s.globalTableRules[table] = rule 287 } else { 288 s.tableRules[table] = rule // 记录已经使用到的rule 289 } 290 return rule, nil 291 } 292 293 // 检查路由规则与现有RouteResult是否一致 294 // 一致的标准: 与RouteResult的db, table一致 295 func (s *StmtInfo) checkStmtRouteResult(rule router.Rule) error { 296 // 如果是全局表, 不需要检查路由规则是否一致, 只记录该规则, 直接返回即可 297 if rule.GetType() == router.GlobalTableRuleType { 298 return nil 299 } 300 301 db := rule.GetDB() 302 var table string 303 if linkedRule, ok := rule.(*router.LinkedRule); ok { 304 table = linkedRule.GetParentTable() 305 } else { 306 table = rule.GetTable() 307 } 308 309 if s.result.db == "" && s.result.table == "" { 310 s.result.db = db 311 s.result.table = table 312 s.result.indexes = rule.GetSubTableIndexes() 313 } else { 314 if err := s.result.Check(db, table); err != nil { 315 return fmt.Errorf("check db and table error: %v", err) 316 } 317 } 318 319 return nil 320 } 321 322 // 用于WHERE条件或JOIN ON条件中, 只存在列名时, 查找对应的路由规则 323 func (s *StmtInfo) getSettedRuleByColumnName(column string) (router.Rule, bool, error) { 324 var columnExistsInShardingTables int // 记录分片表名出现在分片表中分片列的次数 325 var ret router.Rule 326 for _, r := range s.tableRules { 327 if r.GetShardingColumn() == column { 328 columnExistsInShardingTables++ 329 ret = r 330 } 331 } 332 333 if columnExistsInShardingTables > 1 { 334 return nil, false, fmt.Errorf("column %s is ambiguous for sharding", column) 335 } 336 337 return ret, ret != nil, nil 338 } 339 340 // 处理SELECT只含有全局表的情况 341 // 这种情况只路由到默认分片 342 // 如果有多个全局表, 则只取第一个全局表的配置, 因此需要业务上保证这些全局表的配置是一致的. 343 func postHandleGlobalTableRouteResultInQuery(p *StmtInfo) error { 344 if len(p.tableRules) == 0 && len(p.globalTableRules) != 0 { 345 var tableName string 346 var rule router.Rule 347 for t, r := range p.globalTableRules { 348 tableName = t 349 rule = r 350 break 351 } 352 p.result.db = rule.GetDB() 353 p.result.table = tableName 354 p.result.indexes = []int{0} // 全局表SELECT只取默认分片 355 } 356 return nil 357 } 358 359 // 处理UPDATE, DELETE只含有全局表的情况 360 // 这种情况只路由到默认分片 361 // 如果有多个全局表, 则只取第一个全局表的配置, 因此需要业务上保证这些全局表的配置是一致的. 362 func postHandleGlobalTableRouteResultInModify(p *StmtInfo) error { 363 if len(p.tableRules) == 0 && len(p.globalTableRules) != 0 { 364 var tableName string 365 var rule router.Rule 366 for t, r := range p.globalTableRules { 367 tableName = t 368 rule = r 369 break 370 } 371 p.result.db = rule.GetDB() 372 p.result.table = tableName 373 p.result.indexes = rule.GetSubTableIndexes() 374 } 375 return nil 376 } 377 378 // RecordSubqueryTableAlias 记录表名位置的子查询的别名, 便于后续处理 379 // 返回已存在Rule的第一个 (任意一个即可) 380 // 限制: 子查询中的表对应的路由规则必须与外层查询相关联, 或者为全局表 381 func (t *TableAliasStmtInfo) RecordSubqueryTableAlias(alias string) (router.Rule, error) { 382 if alias == "" { 383 return nil, fmt.Errorf("subquery table alias is nil") 384 } 385 386 if len(t.tableRules) == 0 { 387 return nil, fmt.Errorf("no explicit table exist except subquery") 388 } 389 390 table := "gaea_subquery_" + alias 391 if err := t.setTableAlias(table, alias); err != nil { 392 return nil, fmt.Errorf("set subquery table alias error: %v", err) 393 } 394 395 var rule router.Rule 396 for _, r := range t.tableRules { 397 rule = r 398 break 399 } 400 401 t.tableRules[table] = rule 402 return rule, nil 403 } 404 405 // GetSettedRuleFromColumnInfo 用于WHERE条件或JOIN ON条件中, 查找列名对应的路由规则 406 func (t *TableAliasStmtInfo) GetSettedRuleFromColumnInfo(db, table, column string) (router.Rule, bool, bool, error) { 407 if db == "" && table == "" { 408 rule, need, err := t.getSettedRuleByColumnName(column) 409 return rule, need, false, err 410 } 411 412 rule, isAlias, err := t.getSettedRuleFromTable(db, table) 413 return rule, rule != nil, isAlias, err 414 } 415 416 // 用于WHERE条件或JOIN ON条件中, 只存在列名时, 查找对应的路由规则 417 func (t *TableAliasStmtInfo) getSettedRuleByColumnName(column string) (router.Rule, bool, error) { 418 var columnExistsInShardingTables int // 记录分片表名出现在分片表中分片列的次数 419 var ret router.Rule 420 for _, r := range t.tableRules { 421 if r.GetShardingColumn() == column { 422 columnExistsInShardingTables++ 423 ret = r 424 } 425 } 426 427 if columnExistsInShardingTables > 1 { 428 return nil, false, fmt.Errorf("column %s is ambiguous for sharding", column) 429 } 430 431 return ret, ret != nil, nil 432 } 433 434 // 获取FROM TABLE列表中的表数据 435 // 用于FieldList和Where条件中列名的判断 436 func (t *TableAliasStmtInfo) getSettedRuleFromTable(db, table string) (router.Rule, bool, error) { 437 _, err := t.checkAndGetDB(db) 438 if err != nil { 439 return nil, false, err 440 } 441 if rule, ok := t.tableRules[table]; ok { 442 return rule, false, nil 443 } 444 445 if rule, ok := t.globalTableRules[table]; ok { 446 return rule, false, nil 447 } 448 449 if originTable, ok := t.getAliasTable(table); ok { 450 if rule, ok := t.tableRules[originTable]; ok { 451 return rule, true, nil 452 } 453 if rule, ok := t.globalTableRules[originTable]; ok { 454 return rule, true, nil 455 } 456 } 457 458 return nil, false, fmt.Errorf("rule not found") 459 } 460 461 // RecordShardTable 将表信息记录到StmtInfo中, 并返回表信息对应的路由规则 462 func (t *TableAliasStmtInfo) RecordShardTable(db, table, alias string) (router.Rule, error) { 463 rule, err := t.StmtInfo.RecordShardTable(db, table) 464 if err != nil { 465 return nil, fmt.Errorf("record shard table error, db: %s, table: %s, alias: %s, err: %v", db, table, alias, err) 466 } 467 468 if alias != "" { 469 if err := t.setTableAlias(table, alias); err != nil { 470 return nil, fmt.Errorf("set table alias error: %v", err) 471 } 472 } 473 474 return rule, nil 475 } 476 477 func (t *TableAliasStmtInfo) setTableAlias(table, alias string) error { 478 // if not set, set without check 479 originTable, ok := t.tableAlias[alias] 480 if !ok { 481 t.tableAlias[alias] = table 482 return nil 483 } 484 485 if originTable != table { 486 return fmt.Errorf("table alias is set but not match, table: %s, originTable: %s", table, originTable) 487 } 488 489 // already set, return 490 return nil 491 } 492 493 func (t *TableAliasStmtInfo) getAliasTable(alias string) (string, bool) { 494 table, ok := t.tableAlias[alias] 495 return table, ok 496 } 497 498 // 根据StmtNode和路由信息生成分片SQL 499 func generateShardingSQLs(stmt ast.StmtNode, result *RouteResult, router *router.Router) (map[string]map[string][]string, error) { 500 ret := make(map[string]map[string][]string) 501 502 for result.HasNext() { 503 sb := &strings.Builder{} 504 ctx := format.NewRestoreCtx(format.EscapeRestoreFlags, sb) 505 if err := stmt.Restore(ctx); err != nil { 506 return nil, err 507 } 508 509 index := result.Next() 510 rule, ok := router.GetShardRule(result.db, result.table) 511 if !ok { 512 return nil, fmt.Errorf("cannot find shard rule, db: %s, table: %s", result.db, result.table) 513 } 514 sliceIndex := rule.GetSliceIndexFromTableIndex(index) 515 sliceName := rule.GetSlice(sliceIndex) 516 dbName, _ := rule.GetDatabaseNameByTableIndex(index) 517 sliceSQLs, ok := ret[sliceName] 518 if !ok { 519 sliceSQLs = make(map[string][]string) 520 ret[sliceName] = sliceSQLs 521 } 522 523 ret[sliceName][dbName] = append(ret[sliceName][dbName], sb.String()) 524 } 525 526 result.Reset() // must reset the cursor for next call 527 528 return ret, nil 529 } 530 531 // 根据原始SQL生成后端对应slice和db的SQL 532 func generateSQLResultFromOriginSQL(sql string, result *RouteResult, router *router.Router) (map[string]map[string][]string, error) { 533 rule := router.GetRule(result.db, result.table) 534 indexes := rule.GetSubTableIndexes() 535 ret := make(map[string]map[string][]string) 536 for _, index := range indexes { 537 sliceIndex := rule.GetSliceIndexFromTableIndex(index) 538 sliceName := rule.GetSlice(sliceIndex) 539 dbName, _ := rule.GetDatabaseNameByTableIndex(index) 540 sliceSQLs, ok := ret[sliceName] 541 if !ok { 542 sliceSQLs = make(map[string][]string) 543 ret[sliceName] = sliceSQLs 544 } 545 546 ret[sliceName][dbName] = append(ret[sliceName][dbName], sql) 547 } 548 549 return ret, nil 550 } 551 552 // copy from newEmptyResultset 553 // 注意去掉补充的列 554 func newEmptyResultset(info *SelectPlan, stmt *ast.SelectStmt) *mysql.Resultset { 555 r := new(mysql.Resultset) 556 557 fieldLen := len(stmt.Fields.Fields) 558 fieldLen -= info.columnCount - info.originColumnCount 559 560 r.Fields = make([]*mysql.Field, fieldLen) 561 for i, expr := range stmt.Fields.Fields { 562 r.Fields[i] = &mysql.Field{} 563 if expr.WildCard != nil { 564 r.Fields[i].Name = []byte("*") 565 } else { 566 if expr.AsName.String() != "" { 567 r.Fields[i].Name = hack.Slice(expr.AsName.String()) 568 name, _ := parser.NodeToStringWithoutQuote(expr.Expr) 569 r.Fields[i].OrgName = hack.Slice(name) 570 } else { 571 name, _ := parser.NodeToStringWithoutQuote(expr.Expr) 572 r.Fields[i].Name = hack.Slice(name) 573 } 574 } 575 } 576 577 r.Values = make([][]interface{}, 0) 578 r.RowDatas = make([]mysql.RowData, 0) 579 580 return r 581 }