github.com/XiaoMi/Gaea@v1.2.5/proxy/plan/plan_select.go (about) 1 // Copyright 2019 The Gaea Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package plan 16 17 import ( 18 "fmt" 19 20 "github.com/XiaoMi/Gaea/mysql" 21 "github.com/XiaoMi/Gaea/parser/ast" 22 "github.com/XiaoMi/Gaea/parser/opcode" 23 driver "github.com/XiaoMi/Gaea/parser/tidb-types/parser_driver" 24 "github.com/XiaoMi/Gaea/proxy/router" 25 "github.com/XiaoMi/Gaea/util" 26 ) 27 28 // SelectPlan is the plan for select statement 29 type SelectPlan struct { 30 basePlan 31 *TableAliasStmtInfo 32 33 stmt *ast.SelectStmt 34 35 distinct bool // 是否是SELECT DISTINCT 36 groupByColumn []int // GROUP BY 列索引 37 orderByColumn []int // ORDER BY 列索引 38 orderByDirections []bool // ORDER BY 方向, true: DESC 39 originColumnCount int // 补列前的列长度 40 columnCount int // 补列后的列长度 41 42 aggregateFuncs map[int]AggregateFuncMerger // key = column index 43 44 offset int64 // LIMIT offset 45 count int64 // LIMIT count, 未设置则为-1 46 47 sqls map[string]map[string][]string 48 } 49 50 // NewSelectPlan constructor of SelectPlan 51 // db is the session db 52 func NewSelectPlan(db string, sql string, r *router.Router) *SelectPlan { 53 return &SelectPlan{ 54 TableAliasStmtInfo: NewTableAliasStmtInfo(db, sql, r), 55 aggregateFuncs: make(map[int]AggregateFuncMerger), 56 offset: -1, 57 count: -1, 58 } 59 } 60 61 // ExecuteIn implement Plan 62 func (s *SelectPlan) ExecuteIn(reqCtx *util.RequestContext, sess Executor) (*mysql.Result, error) { 63 sqls := s.GetSQLs() 64 if sqls == nil { 65 return nil, fmt.Errorf("SQL has not generated") 66 } 67 68 if len(sqls) == 0 { 69 r := newEmptyResultset(s, s.GetStmt()) 70 ret := &mysql.Result{ 71 Resultset: r, 72 } 73 return ret, nil 74 } 75 76 rs, err := sess.ExecuteSQLs(reqCtx, sqls) 77 if err != nil { 78 return nil, fmt.Errorf("execute in SelectPlan error: %v", err) 79 } 80 81 if s.isExecOnSingleNode() { 82 return rs[0], nil 83 } else { 84 r, err := MergeSelectResult(s, s.stmt, rs) 85 if err != nil { 86 return nil, fmt.Errorf("merge select result error: %v", err) 87 } 88 return r, nil 89 } 90 } 91 92 // GetStmt SelectStmt 93 func (s *SelectPlan) GetStmt() *ast.SelectStmt { 94 return s.stmt 95 } 96 97 func (s *SelectPlan) setAggregateFuncMerger(idx int, merger AggregateFuncMerger) error { 98 if _, ok := s.aggregateFuncs[idx]; ok { 99 return fmt.Errorf("column %d already set", idx) 100 } 101 s.aggregateFuncs[idx] = merger 102 return nil 103 } 104 105 // HasLimit if the select statement has limit clause, return true 106 func (s *SelectPlan) HasLimit() bool { 107 return s.count != -1 108 } 109 110 // GetLimitValue get offset, count in limit clause 111 func (s *SelectPlan) GetLimitValue() (int64, int64) { 112 return s.offset, s.count 113 } 114 115 // HasGroupBy if the select statement has group by clause, return true 116 func (s *SelectPlan) HasGroupBy() bool { 117 return len(s.groupByColumn) != 0 118 } 119 120 // GetOriginColumnCount get origin column count in statement, 121 // since group by and order by may add extra columns to FieldList. 122 func (s *SelectPlan) GetOriginColumnCount() int { 123 return s.originColumnCount 124 } 125 126 // GetColumnCount get column count with extra columns 127 func (s *SelectPlan) GetColumnCount() int { 128 return s.columnCount 129 } 130 131 // GetGroupByColumnInfo get extra column offset and length for group by 132 func (s *SelectPlan) GetGroupByColumnInfo() []int { 133 return s.groupByColumn 134 } 135 136 // HasOrderBy if select statement has order by clause, return true 137 func (s *SelectPlan) HasOrderBy() bool { 138 return len(s.orderByDirections) != 0 139 } 140 141 // GetOrderByColumnInfo get extra column offset and length for order by 142 func (s *SelectPlan) GetOrderByColumnInfo() ([]int, []bool) { 143 return s.orderByColumn, s.orderByDirections 144 } 145 146 // GetSQLs get generated SQLs 147 // the first key is slice, the second key is backend database name, the value is sql list. 148 func (s *SelectPlan) GetSQLs() map[string]map[string][]string { 149 return s.sqls 150 } 151 152 //执行计划是否仅仅涉及一个分片 153 func (s *SelectPlan) isExecOnSingleNode() bool { 154 if len(s.result.indexes) == 1 { 155 return true 156 } else { 157 return false 158 } 159 } 160 161 // HandleSelectStmt build a SelectPlan 162 // 处理SelectStmt语法树, 改写其中一些节点, 并获取路由信息和结果聚合函数 163 func HandleSelectStmt(p *SelectPlan, stmt *ast.SelectStmt) error { 164 p.stmt = stmt // hold the reference of stmt 165 166 p.distinct = stmt.Distinct 167 168 if err := handleTableRefs(p, stmt); err != nil { 169 return fmt.Errorf("handle From error: %v", err) 170 } 171 172 // field list的处理必须在group by之前, 因为group by, order by会补列, 而这些补充的列是已经处理过的 173 if stmt.Fields != nil { 174 if err := handleFieldList(p, stmt); err != nil { 175 return fmt.Errorf("handle Fields error: %v", err) 176 } 177 178 // 记录补列前的Fields长度 179 p.originColumnCount = len(stmt.Fields.Fields) 180 } 181 182 if err := handleWhere(p, stmt); err != nil { 183 return fmt.Errorf("handle Where error: %v", err) 184 } 185 186 if err := postHandleHintDatabaseFunction(p); err != nil { 187 return fmt.Errorf("handle Hint error: %v", err) 188 } 189 190 191 //如果是多节点执行,特殊处理groupby orderby having limit 192 if !p.isExecOnSingleNode() { 193 // group by的处理必须在table处理之后 194 if err := handleGroupBy(p, stmt); err != nil { 195 return fmt.Errorf("handle GroupBy error: %v", err) 196 } 197 198 // order by的处理必须在table处理之后 199 // 与group by补列的顺序没有要求, 只要保证处理返回结果去掉这些补充列时保持相反的顺序, 这里放在group by之后 200 if err := handleOrderBy(p, stmt); err != nil { 201 return fmt.Errorf("handle OrderBy error: %v", err) 202 } 203 204 handleExtraFieldList(p, stmt) 205 206 // 记录补列后的Fields长度, 后面的handler不会补列了 207 if stmt.Fields != nil { 208 p.columnCount = len(stmt.Fields.Fields) 209 } 210 211 212 213 if err := handleHaving(p, stmt); err != nil { 214 return fmt.Errorf("handle Having error: %v", err) 215 } 216 217 if err := handleLimit(p, stmt); err != nil { 218 return fmt.Errorf("handle Limit error: %v", err) 219 } 220 } 221 222 if err := postHandleGlobalTableRouteResultInQuery(p.StmtInfo); err != nil { 223 return fmt.Errorf("post handle global table error: %v", err) 224 } 225 226 sqls, err := generateShardingSQLs(p.stmt, p.result, p.router) 227 if err != nil { 228 return fmt.Errorf("generate select SQL error: %v", err) 229 } 230 231 p.sqls = sqls 232 233 return nil 234 } 235 236 // 处理GroupBy, 把GroupBy的列补到FieldList中, 然后把GroupBy去掉 237 func handleGroupBy(p *SelectPlan, stmt *ast.SelectStmt) error { 238 if stmt.GroupBy == nil { 239 return nil 240 } 241 242 groupByFields, err := createSelectFieldsFromByItems(p, stmt.GroupBy.Items) 243 if err != nil { 244 return fmt.Errorf("get group by fields error: %v", err) 245 } 246 247 for i := 0; i < len(groupByFields); i++ { 248 p.groupByColumn = append(p.groupByColumn, i+len(stmt.Fields.Fields)) 249 } 250 251 // append group by fields 252 stmt.Fields.Fields = append(stmt.Fields.Fields, groupByFields...) 253 254 return nil 255 } 256 257 func handleOrderBy(p *SelectPlan, stmt *ast.SelectStmt) error { 258 if stmt.OrderBy == nil { 259 return nil 260 } 261 262 orderByFields, err := createSelectFieldsFromByItems(p, stmt.OrderBy.Items) 263 if err != nil { 264 return fmt.Errorf("get order by fields error: %v", err) 265 } 266 267 for i := 0; i < len(orderByFields); i++ { 268 p.orderByColumn = append(p.orderByColumn, i+len(stmt.Fields.Fields)) 269 } 270 271 for _, f := range stmt.OrderBy.Items { 272 p.orderByDirections = append(p.orderByDirections, f.Desc) 273 } 274 275 stmt.Fields.Fields = append(stmt.Fields.Fields, orderByFields...) 276 return nil 277 } 278 279 func handleExtraFieldList(p *SelectPlan, stmt *ast.SelectStmt) { 280 selectFields := make(map[string]int) 281 for i := 0; i < p.originColumnCount; i++ { 282 field := stmt.Fields.Fields[i] 283 if field.AsName.L != "" { 284 selectFields[field.AsName.L] = i 285 } 286 if field, isColumnExpr := stmt.Fields.Fields[i].Expr.(*ast.ColumnNameExpr); isColumnExpr { 287 selectFields[field.Name.Name.L] = i 288 } 289 } 290 291 deleteNum := 0 292 for i := 0; i < len(p.groupByColumn); i++ { 293 p.groupByColumn[i] -= deleteNum 294 currColumnIndex := p.originColumnCount + i - deleteNum 295 field, isColumnExpr := stmt.Fields.Fields[currColumnIndex].Expr.(*ast.ColumnNameExpr) 296 if !isColumnExpr { 297 continue 298 } 299 if index, ok := selectFields[field.Name.Name.L]; !ok { 300 continue 301 } else { 302 stmt.Fields.Fields = append(stmt.Fields.Fields[:currColumnIndex], stmt.Fields.Fields[currColumnIndex+1:]...) 303 p.groupByColumn[i] = index 304 deleteNum++ 305 } 306 } 307 308 for i := 0; i < len(p.orderByColumn); i++ { 309 p.orderByColumn[i] -= deleteNum 310 currColumnIndex := p.originColumnCount + len(p.groupByColumn) + i - deleteNum 311 field, isColumnExpr := stmt.Fields.Fields[currColumnIndex].Expr.(*ast.ColumnNameExpr) 312 if !isColumnExpr { 313 continue 314 } 315 if index, ok := selectFields[field.Name.Name.L]; !ok { 316 continue 317 } else { 318 stmt.Fields.Fields = append(stmt.Fields.Fields[:currColumnIndex], stmt.Fields.Fields[currColumnIndex+1:]...) 319 p.orderByColumn[i] = index 320 deleteNum++ 321 } 322 } 323 } 324 325 func createSelectFieldsFromByItems(p *SelectPlan, items []*ast.ByItem) ([]*ast.SelectField, error) { 326 var ret []*ast.SelectField 327 for _, item := range items { 328 selectField, err := createSelectFieldFromByItem(p, item) 329 if err != nil { 330 return nil, err 331 } 332 ret = append(ret, selectField) 333 } 334 return ret, nil 335 } 336 337 func createSelectFieldFromByItem(p *SelectPlan, item *ast.ByItem) (*ast.SelectField, error) { 338 // 特殊处理DATABASE()这种情况 339 if funcExpr, ok := item.Expr.(*ast.FuncCallExpr); ok { 340 if funcExpr.FnName.L == "database" { 341 ret := &ast.SelectField{ 342 Expr: item.Expr, 343 } 344 return ret, nil 345 } 346 return nil, fmt.Errorf("ByItem.Expr is a FuncCallExpr but not DATABASE()") 347 } 348 349 columnExpr, ok := item.Expr.(*ast.ColumnNameExpr) 350 if !ok { 351 return nil, fmt.Errorf("ByItem.Expr is not a ColumnNameExpr") 352 } 353 354 rule, need, isAlias, err := NeedCreateColumnNameExprDecoratorInField(p.TableAliasStmtInfo, columnExpr) 355 if err != nil { 356 return nil, err 357 } 358 359 if need { 360 decorator := CreateColumnNameExprDecorator(columnExpr, rule, isAlias, p.GetRouteResult()) 361 item.Expr = decorator 362 } 363 364 ret := &ast.SelectField{ 365 Expr: item.Expr, 366 } 367 return ret, nil 368 } 369 370 // 处理from table和join on部分 371 // 主要是改写table ExprNode, 并找到路由条件 372 func handleTableRefs(p *SelectPlan, stmt *ast.SelectStmt) error { 373 tableRefs := stmt.From 374 if tableRefs == nil { 375 return nil 376 } 377 378 join := tableRefs.TableRefs 379 if join == nil { 380 return nil 381 } 382 383 return handleJoin(p.TableAliasStmtInfo, join) 384 } 385 386 func handleJoin(p *TableAliasStmtInfo, join *ast.Join) error { 387 if err := precheckJoinClause(join); err != nil { 388 return fmt.Errorf("precheck Join error: %v", err) 389 } 390 391 // 只允许最多两个表的JOIN 392 if join.Left != nil { 393 switch left := join.Left.(type) { 394 case *ast.TableSource: 395 // 改写两个表的node 396 if err := rewriteTableSource(p, left); err != nil { 397 return fmt.Errorf("rewrite left TableSource error: %v", err) 398 } 399 case *ast.Join: 400 if err := handleJoin(p, left); err != nil { 401 return fmt.Errorf("handle nested left Join error: %v", err) 402 } 403 default: 404 return fmt.Errorf("invalid left Join type: %T", join.Left) 405 } 406 } 407 if join.Right != nil { 408 right, ok := join.Right.(*ast.TableSource) 409 if !ok { 410 return fmt.Errorf("right is not TableSource, type: %T", join.Right) 411 } 412 413 if err := rewriteTableSource(p, right); err != nil { 414 return fmt.Errorf("rewrite right TableSource error: %v", err) 415 } 416 } 417 418 // 改写ON条件 419 if join.On != nil { 420 err := rewriteOnCondition(p, join.On) 421 if err != nil { 422 return fmt.Errorf("rewrite on condition error: %v", err) 423 } 424 } 425 426 return nil 427 } 428 429 func handleWhere(p *SelectPlan, stmt *ast.SelectStmt) (err error) { 430 if stmt.Where == nil { 431 return nil 432 } 433 434 has, result, decorator, err := handleComparisonExpr(p.TableAliasStmtInfo, stmt.Where) 435 if err != nil { 436 return fmt.Errorf("rewrite Where error: %v", err) 437 } 438 if has { 439 p.GetRouteResult().Inter(result) 440 } 441 stmt.Where = decorator 442 return nil 443 } 444 445 // 检查TableRefs中存在的不允许在分表中执行的语法 446 func precheckJoinClause(join *ast.Join) error { 447 // 不允许USING的列名中出现DB名和表名, 因为目前Join子句的TableName不方便加装饰器 448 for _, c := range join.Using { 449 if c.Schema.String() != "" { 450 return fmt.Errorf("JOIN does not support USING column with schema") 451 } 452 if c.Table.String() != "" { 453 return fmt.Errorf("JOIN does not support USING column with table") 454 } 455 } 456 return nil 457 } 458 459 // 改写TableSource节点, 得到一个装饰器 460 // Source必须为TableName节点或子查询 461 func rewriteTableSource(p *TableAliasStmtInfo, tableSource *ast.TableSource) error { 462 switch ss := tableSource.Source.(type) { 463 case *ast.TableName: 464 return rewriteTableNameInTableSource(p, tableSource) 465 case *ast.SelectStmt: 466 if err := handleSubquerySelectStmt(p, ss); err != nil { 467 return fmt.Errorf("handleSubquerySelectStmt error: %v", err) 468 } 469 alias := tableSource.AsName.L 470 if alias != "" { 471 if _, err := p.RecordSubqueryTableAlias(alias); err != nil { 472 return fmt.Errorf("record subquery alias error: %v", err) 473 } 474 } 475 return nil 476 default: 477 return fmt.Errorf("field Source cannot handle, type: %T", tableSource.Source) 478 } 479 } 480 481 func rewriteTableNameInTableSource(p *TableAliasStmtInfo, tableSource *ast.TableSource) error { 482 tableName, ok := tableSource.Source.(*ast.TableName) 483 if !ok { 484 return fmt.Errorf("field Source is not type of TableName, type: %T", tableSource.Source) 485 } 486 alias := tableSource.AsName.L 487 488 rule, need, err := NeedCreateTableNameDecorator(p, tableName, alias) 489 if err != nil { 490 return fmt.Errorf("check NeedCreateTableNameDecorator error: %v", err) 491 } 492 493 if !need { 494 return nil 495 } 496 497 // 这是一个分片表或关联表, 创建一个TableName的装饰器, 并替换原有节点 498 d, err := CreateTableNameDecorator(tableName, rule, p.GetRouteResult()) 499 if err != nil { 500 return fmt.Errorf("create TableNameDecorator error: %v", err) 501 } 502 tableSource.Source = d 503 return nil 504 } 505 506 func rewriteOnCondition(p *TableAliasStmtInfo, on *ast.OnCondition) error { 507 has, result, decorator, err := handleComparisonExpr(p, on.Expr) 508 if err != nil { 509 return fmt.Errorf("rewrite Expr in OnCondition error: %v", err) 510 } 511 if has { 512 p.GetRouteResult().Inter(result) 513 } 514 on.Expr = decorator 515 return nil 516 } 517 518 // 处理info中的hint 519 // 目前只有mycat路由方式支持 520 // hint路由会覆盖遍历语法树时计算出的路由 521 func postHandleHintDatabaseFunction(p *SelectPlan) error { 522 if p.hintPhyDB == "" { 523 return nil 524 } 525 526 rule, ok := p.router.GetShardRule(p.result.db, p.result.table) 527 if !ok { 528 return fmt.Errorf("sharding rule of route result not found, result: %v", p.result) 529 } 530 mr, ok := rule.(router.MycatRule) 531 if !ok { 532 return fmt.Errorf("sharding rule is not mycat mode, result: %v", p.result) 533 } 534 535 if !router.IsMycatShardingRule(mr.GetType()) { // TODO: need refactor, why is MycatRule's type not mycat rule? 536 return fmt.Errorf("only mycat rule supports database function hint") 537 } 538 539 idx, ok := mr.GetTableIndexByDatabaseName(p.hintPhyDB) 540 if !ok { 541 return fmt.Errorf("hint db not found: %s", p.hintPhyDB) 542 } 543 544 p.result.indexes = []int{idx} 545 return nil 546 } 547 548 // ColumnNameRewriteVisitor visit ColumnNameExpr, check if need decorate, and then decorate it. 549 type ColumnNameRewriteVisitor struct { 550 info *TableAliasStmtInfo 551 } 552 553 // NewColumnNameRewriteVisitor constructor of ColumnNameRewriteVisitor 554 func NewColumnNameRewriteVisitor(p *TableAliasStmtInfo) *ColumnNameRewriteVisitor { 555 return &ColumnNameRewriteVisitor{ 556 info: p, 557 } 558 } 559 560 // Enter implement ast.Visitor 561 func (s *ColumnNameRewriteVisitor) Enter(n ast.Node) (node ast.Node, skipChildren bool) { 562 return n, false 563 } 564 565 // Leave implement ast.Visitor 566 func (s *ColumnNameRewriteVisitor) Leave(n ast.Node) (node ast.Node, ok bool) { 567 field, ok := n.(*ast.ColumnNameExpr) 568 if !ok { 569 return n, true 570 } 571 rule, need, isAlias, err := NeedCreateColumnNameExprDecoratorInField(s.info, field) 572 if err != nil { 573 panic(fmt.Errorf("check NeedCreateColumnNameExprDecoratorInField in ColumnNameExpr error: %v", err)) 574 } 575 if need { 576 decorator := CreateColumnNameExprDecorator(field, rule, isAlias, s.info.GetRouteResult()) 577 return decorator, true 578 } 579 580 return n, true 581 } 582 583 func handleFieldList(p *SelectPlan, stmt *ast.SelectStmt) (err error) { 584 defer func() { 585 if e := recover(); e != nil { 586 err = fmt.Errorf("handleFieldList panic: %v", e) 587 } 588 }() 589 590 fields := stmt.Fields 591 if fields == nil { 592 return nil 593 } 594 595 // 先用一个Visitor生成一个替换表名的装饰器 596 // 这里如果出错, 只能通过panic返回err 597 columnNameRewriter := NewColumnNameRewriteVisitor(p.TableAliasStmtInfo) 598 fields.Accept(columnNameRewriter) 599 600 // 如果最外层是聚合函数, 则生成一个聚合函数装饰器, 并记录对应的列位置 601 // 只处理最外层的聚合函数. 602 for i, f := range fields.Fields { 603 switch field := f.Expr.(type) { 604 case *ast.AggregateFuncExpr: 605 merger, err := CreateAggregateFunctionMerger(field.F, i) 606 if err != nil { 607 return fmt.Errorf("create aggregate function merger error, column index: %d, err: %v", i, err) 608 } 609 if err := p.setAggregateFuncMerger(i, merger); err != nil { 610 return fmt.Errorf("set aggregate function merger error, column index: %d, err: %v", i, err) 611 } 612 default: 613 // do nothing 614 } 615 } 616 return nil 617 } 618 619 func handleHaving(p *SelectPlan, stmt *ast.SelectStmt) (err error) { 620 defer func() { 621 if e := recover(); e != nil { 622 err = fmt.Errorf("handleHaving panic: %v", e) 623 } 624 }() 625 626 having := stmt.Having 627 if having == nil { 628 return nil 629 } 630 631 // 先用一个Visitor生成一个替换表名的装饰器 632 // 这里如果出错, 只能通过panic返回err 633 columnNameRewriter := NewColumnNameRewriteVisitor(p.TableAliasStmtInfo) 634 having.Accept(columnNameRewriter) 635 return nil 636 } 637 638 func handleComparisonExpr(p *TableAliasStmtInfo, comp ast.ExprNode) (bool, []int, ast.ExprNode, error) { 639 switch expr := comp.(type) { 640 case *ast.BinaryOperationExpr: 641 return handleBinaryOperationExpr(p, expr) 642 case *ast.PatternInExpr: 643 return handlePatternInExpr(p, expr) 644 case *ast.BetweenExpr: 645 return handleBetweenExpr(p, expr) 646 case *ast.ParenthesesExpr: 647 has, routeResult, newExpr, err := handleComparisonExpr(p, expr.Expr) 648 expr.Expr = newExpr 649 return has, routeResult, expr, err 650 default: 651 // 其他情况只替换表名 (但是不处理根节点是ColumnNameExpr的情况, 理论上也不会出现这种情况) 652 columnNameRewriter := NewColumnNameRewriteVisitor(p) 653 expr.Accept(columnNameRewriter) 654 return false, p.GetRouteResult().GetShardIndexes(), comp, nil 655 } 656 } 657 658 func handlePatternInExpr(p *TableAliasStmtInfo, expr *ast.PatternInExpr) (bool, []int, ast.ExprNode, error) { 659 rule, need, isAlias, err := NeedCreatePatternInExprDecorator(p, expr) 660 if err != nil { 661 return false, nil, nil, fmt.Errorf("check PatternInExpr error: %v", err) 662 } 663 if !need { 664 return false, nil, expr, nil 665 } 666 decorator, err := CreatePatternInExprDecorator(expr, rule, isAlias, p.GetRouteResult()) 667 if err != nil { 668 return false, nil, nil, fmt.Errorf("create PatternInExprDecorator error: %v", err) 669 } 670 return true, decorator.GetCurrentRouteResult(), decorator, nil 671 } 672 673 func handleBetweenExpr(p *TableAliasStmtInfo, expr *ast.BetweenExpr) (bool, []int, ast.ExprNode, error) { 674 rule, need, isAlias, err := NeedCreateBetweenExprDecorator(p, expr) 675 if err != nil { 676 return false, nil, nil, fmt.Errorf("check BetweenExpr error: %v", err) 677 } 678 if !need { 679 return false, nil, expr, nil 680 } 681 682 decorator, err := CreateBetweenExprDecorator(expr, rule, isAlias, p.GetRouteResult()) 683 if err != nil { 684 return false, nil, nil, fmt.Errorf("create CreateBetweenExprDecorator error: %v", err) 685 } 686 687 return true, decorator.GetCurrentRouteResult(), decorator, nil 688 } 689 690 // return value: hasRoutingResult, RouteResult, Decorator, error 691 // the Decorator must not be nil. If no modification to the input expr, just return it. 692 func handleBinaryOperationExpr(p *TableAliasStmtInfo, expr *ast.BinaryOperationExpr) (bool, []int, ast.ExprNode, error) { 693 _, ok := opcode.Ops[expr.Op] 694 if !ok { 695 return false, nil, nil, fmt.Errorf("unknown BinaryOperationExpr.Op: %v", expr.Op) 696 } 697 698 switch expr.Op { 699 case opcode.LogicAnd, opcode.LogicOr: 700 return handleBinaryOperationExprLogic(p, expr) 701 case opcode.EQ, opcode.NE, opcode.GT, opcode.GE, opcode.LT, opcode.LE: 702 return handleBinaryOperationExprMathCompare(p, expr) 703 default: 704 return handleBinaryOperationExprOther(p, expr) 705 } 706 } 707 708 // 处理逻辑比较运算 709 func handleBinaryOperationExprLogic(p *TableAliasStmtInfo, expr *ast.BinaryOperationExpr) (bool, []int, ast.ExprNode, error) { 710 lHas, lResult, lDecorator, lErr := handleComparisonExpr(p, expr.L) 711 if lErr != nil { 712 return false, nil, nil, fmt.Errorf("handle BinaryOperationExpr.L error: %v", lErr) 713 } 714 rHas, rResult, rDecorator, rErr := handleComparisonExpr(p, expr.R) 715 if rErr != nil { 716 return false, nil, nil, fmt.Errorf("handle BinaryOperationExpr.R error: %v", rErr) 717 } 718 719 if lDecorator != nil { 720 expr.L = lDecorator 721 } 722 if rDecorator != nil { 723 expr.R = rDecorator 724 } 725 726 has, result := mergeBinaryOperationRouteResult(expr.Op, lHas, lResult, rHas, rResult) 727 return has, result, expr, nil 728 } 729 730 // 处理算术比较运算 731 // 如果出现列名, 则必须为列名与列名比较, 列名与值比较, 否则会报错 (比如 id + 2 = 3 就会报错, 因为 id + 2 处理不了) 732 // 如果是其他情况, 则直接返回 (如 1 = 1 这种) 733 func handleBinaryOperationExprMathCompare(p *TableAliasStmtInfo, expr *ast.BinaryOperationExpr) (bool, []int, ast.ExprNode, error) { 734 lType := getExprNodeTypeInBinaryOperation(expr.L) 735 rType := getExprNodeTypeInBinaryOperation(expr.R) 736 737 // handle hint database function: SELECT * from tbl where DATABASE() = db_0 / 'db_0' / `db_0` 738 if expr.Op == opcode.EQ { 739 if lType == FuncCallExpr { 740 hintDB, err := getDatabaseFuncHint(expr.L.(*ast.FuncCallExpr), expr.R) 741 if err != nil { 742 return false, nil, nil, fmt.Errorf("get database function hint error: %v", err) 743 } 744 if hintDB != "" { 745 p.hintPhyDB = hintDB 746 return false, nil, expr, nil 747 } 748 } else if rType == FuncCallExpr { 749 hintDB, err := getDatabaseFuncHint(expr.R.(*ast.FuncCallExpr), expr.L) 750 if err != nil { 751 return false, nil, nil, fmt.Errorf("get database function hint error: %v", err) 752 } 753 if hintDB != "" { 754 p.hintPhyDB = hintDB 755 return false, nil, expr, nil 756 } 757 } 758 } 759 760 if lType == ColumnNameExpr && rType == ColumnNameExpr { 761 return handleBinaryOperationExprCompareLeftColumnRightColumn(p, expr) 762 } 763 764 if lType == ColumnNameExpr { 765 if rType == ValueExpr { 766 return handleBinaryOperationExprCompareLeftColumnRightValue(p, expr, getFindTableIndexesFunc(expr.Op)) 767 } 768 column := expr.L.(*ast.ColumnNameExpr) 769 rule, need, isAlias, err := NeedCreateColumnNameExprDecoratorInCondition(p, column) 770 if err != nil { 771 return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.L: %v", err) 772 } 773 if !need { 774 return false, nil, expr, nil 775 } 776 777 decorator := CreateColumnNameExprDecorator(column, rule, isAlias, p.GetRouteResult()) 778 expr.L = decorator 779 return false, nil, expr, nil 780 } 781 782 if rType == ColumnNameExpr { 783 if lType == ValueExpr { 784 return handleBinaryOperationExprCompareLeftValueRightColumn(p, expr, getFindTableIndexesFunc(inverseOperator(expr.Op))) 785 } 786 column := expr.R.(*ast.ColumnNameExpr) 787 rule, need, isAlias, err := NeedCreateColumnNameExprDecoratorInCondition(p, column) 788 if err != nil { 789 return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.R: %v", err) 790 } 791 if !need { 792 return false, nil, expr, nil 793 } 794 795 decorator := CreateColumnNameExprDecorator(column, rule, isAlias, p.GetRouteResult()) 796 expr.R = decorator 797 return false, nil, expr, nil 798 } 799 800 return false, nil, expr, nil 801 } 802 803 // 处理其他情况的运算 804 // 如果出现分表列, 只创建一个替换表名的装饰器, 不计算路由. 因此返回结果前两个一定是false, nil 805 func handleBinaryOperationExprOther(p *TableAliasStmtInfo, expr *ast.BinaryOperationExpr) (bool, []int, ast.ExprNode, error) { 806 if lColumn, ok := expr.L.(*ast.ColumnNameExpr); ok { 807 lRule, lNeed, lIsAlias, lErr := NeedCreateColumnNameExprDecoratorInCondition(p, lColumn) 808 if lErr != nil { 809 return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.L: %v", lErr) 810 } 811 812 if lNeed { 813 lDecorator := CreateColumnNameExprDecorator(lColumn, lRule, lIsAlias, p.GetRouteResult()) 814 expr.L = lDecorator 815 } 816 } 817 if rColumn, ok := expr.R.(*ast.ColumnNameExpr); ok { 818 rRule, rNeed, rIsAlias, rErr := NeedCreateColumnNameExprDecoratorInCondition(p, rColumn) 819 if rErr != nil { 820 return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.R: %v", rErr) 821 } 822 if rNeed { 823 rDecorator := CreateColumnNameExprDecorator(rColumn, rRule, rIsAlias, p.GetRouteResult()) 824 expr.R = rDecorator 825 } 826 } 827 return false, nil, expr, nil 828 } 829 830 // 获取mycat路由模式下的hint物理DB名 831 func getDatabaseFuncHint(f *ast.FuncCallExpr, v ast.ExprNode) (string, error) { 832 if f.FnName.L != "database" { 833 return "", nil 834 } 835 switch vv := v.(type) { 836 case *ast.ColumnNameExpr: 837 return vv.Name.Name.String(), nil 838 case *driver.ValueExpr: 839 return vv.GetString(), nil 840 default: 841 return "", fmt.Errorf("invalid value type of database function hint: %T", v) 842 } 843 } 844 845 // 返回一个根据路由信息和路由值获取路由结果的函数 846 // 左边为列名, 右边为参数 847 func getFindTableIndexesFunc(op opcode.Op) func(rule router.Rule, columnName string, v interface{}) ([]int, error) { 848 findTableIndexesFunc := func(rule router.Rule, columnName string, v interface{}) ([]int, error) { 849 // 如果不是分表列, 则需要返回所有分片 850 if rule.GetShardingColumn() != columnName { 851 return rule.GetSubTableIndexes(), nil 852 } 853 854 // 如果是分表列, 还需要根据运算符判断 855 switch op { 856 case opcode.EQ: 857 index, err := rule.FindTableIndex(v) 858 if err != nil { 859 return nil, err 860 } 861 return []int{index}, nil 862 case opcode.NE: 863 return rule.GetSubTableIndexes(), nil 864 case opcode.GT, opcode.GE, opcode.LT, opcode.LE: 865 // 如果是range路由, 需要做一些特殊处理 866 if rangeShard, ok := rule.GetShard().(router.RangeShard); ok { 867 index, err := rule.FindTableIndex(v) 868 if err != nil { 869 return nil, err 870 } 871 if op == opcode.LT || op == opcode.LE { 872 if op == opcode.LT { 873 index = adjustShardIndex(rangeShard, v, index) 874 } 875 return makeList(rule.GetFirstTableIndex(), index+1), nil 876 } else { 877 return makeList(index, rule.GetLastTableIndex()+1), nil 878 } 879 } 880 881 // 如果不是 (即hash路由), 则返回所有分片 882 return rule.GetSubTableIndexes(), nil 883 default: // should not going here 884 return rule.GetSubTableIndexes(), nil 885 } 886 } 887 888 return findTableIndexesFunc 889 } 890 891 // copy from PlanBuilder.adjustShardIndex() 892 func adjustShardIndex(s router.RangeShard, value interface{}, index int) int { 893 if s.EqualStart(value, index) { 894 return index - 1 895 } 896 return index 897 } 898 899 func inverseOperator(op opcode.Op) opcode.Op { 900 switch op { 901 case opcode.GT: 902 return opcode.LT 903 case opcode.GE: 904 return opcode.LE 905 case opcode.LT: 906 return opcode.GT 907 case opcode.LE: 908 return opcode.GE 909 default: 910 return op 911 } 912 } 913 914 func handleBinaryOperationExprCompareLeftColumnRightColumn(p *TableAliasStmtInfo, expr *ast.BinaryOperationExpr) (bool, []int, ast.ExprNode, error) { 915 lColumn := expr.L.(*ast.ColumnNameExpr) 916 lRule, lNeed, lIsAlias, lErr := NeedCreateColumnNameExprDecoratorInCondition(p, lColumn) 917 if lErr != nil { 918 return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.L: %v", lErr) 919 } 920 rColumn := expr.R.(*ast.ColumnNameExpr) 921 rRule, rNeed, rIsAlias, rErr := NeedCreateColumnNameExprDecoratorInCondition(p, rColumn) 922 if rErr != nil { 923 return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.R: %v", rErr) 924 } 925 926 if lNeed { 927 lDecorator := CreateColumnNameExprDecorator(lColumn, lRule, lIsAlias, p.GetRouteResult()) 928 expr.L = lDecorator 929 } 930 if rNeed { 931 rDecorator := CreateColumnNameExprDecorator(rColumn, rRule, rIsAlias, p.GetRouteResult()) 932 expr.R = rDecorator 933 } 934 return false, nil, expr, nil 935 } 936 937 func handleBinaryOperationExprCompareLeftColumnRightValue(p *TableAliasStmtInfo, expr *ast.BinaryOperationExpr, findTableIndexes func(router.Rule, string, interface{}) ([]int, error)) (bool, []int, ast.ExprNode, error) { 938 column := expr.L.(*ast.ColumnNameExpr) 939 rule, need, isAlias, err := NeedCreateColumnNameExprDecoratorInCondition(p, column) 940 if err != nil { 941 return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.L: %v", err) 942 } 943 if !need { 944 return false, nil, expr, nil 945 } 946 947 decorator := CreateColumnNameExprDecorator(column, rule, isAlias, p.GetRouteResult()) 948 expr.L = decorator 949 950 if rule.GetType() == router.GlobalTableRuleType { 951 return false, nil, expr, nil 952 } 953 954 valueExpr := expr.R.(*driver.ValueExpr) 955 v, err := util.GetValueExprResult(valueExpr) 956 if err != nil { 957 return false, nil, nil, fmt.Errorf("get ValueExpr value error: %v", err) 958 } 959 960 tableIndexes, err := findTableIndexes(rule, column.Name.Name.L, v) 961 if err != nil { 962 return false, nil, nil, fmt.Errorf("find table index error: %v", err) 963 } 964 965 return true, tableIndexes, expr, nil 966 } 967 968 func handleBinaryOperationExprCompareLeftValueRightColumn(p *TableAliasStmtInfo, expr *ast.BinaryOperationExpr, findTableIndexes func(router.Rule, string, interface{}) ([]int, error)) (bool, []int, ast.ExprNode, error) { 969 column := expr.R.(*ast.ColumnNameExpr) 970 rule, need, isAlias, err := NeedCreateColumnNameExprDecoratorInCondition(p, column) 971 if err != nil { 972 return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.R: %v", err) 973 } 974 if !need { 975 return false, nil, expr, nil 976 } 977 978 decorator := CreateColumnNameExprDecorator(column, rule, isAlias, p.GetRouteResult()) 979 expr.R = decorator 980 981 if rule.GetType() == router.GlobalTableRuleType { 982 return false, nil, expr, nil 983 } 984 985 valueExpr := expr.L.(*driver.ValueExpr) 986 v, err := util.GetValueExprResult(valueExpr) 987 if err != nil { 988 return false, nil, nil, fmt.Errorf("get ValueExpr value error: %v", err) 989 } 990 991 tableIndexes, err := findTableIndexes(rule, column.Name.Name.L, v) 992 if err != nil { 993 return false, nil, nil, fmt.Errorf("find table index error: %v", err) 994 } 995 996 return true, tableIndexes, expr, nil 997 } 998 999 func mergeBinaryOperationRouteResult(op opcode.Op, lHas bool, lResult []int, rHas bool, rResult []int) (bool, []int) { 1000 switch op { 1001 case opcode.LogicAnd: 1002 if lHas == false && rHas == false { 1003 return false, nil 1004 } 1005 if lHas && rHas { 1006 return true, interList(lResult, rResult) 1007 } 1008 if lHas { 1009 return true, lResult 1010 } 1011 if rHas { 1012 return true, rResult 1013 } 1014 case opcode.LogicOr: 1015 if lHas && rHas { 1016 return true, unionList(lResult, rResult) 1017 } 1018 return false, nil 1019 } 1020 return false, nil 1021 } 1022 1023 func handleLimit(p *SelectPlan, stmt *ast.SelectStmt) error { 1024 need, originOffset, originCount, newLimit := NeedRewriteLimitOrCreateRewrite(stmt) 1025 p.offset = originOffset 1026 p.count = originCount 1027 if need { 1028 stmt.Limit = newLimit 1029 } 1030 return nil 1031 } 1032 1033 func getTableInfoFromTableName(t *ast.TableName) (string, string) { 1034 return t.Schema.O, t.Name.L 1035 } 1036 1037 func getColumnInfoFromColumnName(t *ast.ColumnName) (string, string, string) { 1038 return t.Schema.O, t.Table.L, t.Name.L 1039 }